shithub: riscv

Download patch

ref: 56836bfdbdca9fd6a5b608d249d178a22d3337d8
parent: be5992955d4e417ca625b07af93a800464d4c11f
author: cinap_lenrek <[email protected]>
date: Sat Sep 14 15:19:08 EDT 2013

tls: fix various tlsClient()/tlsServer() related bugs

- TLSconn structure allocated on the stack but not initialized (zeroed)
- original file descriptor closed twice in the error case
- original file descriptor leaked in the success case
- TLSconn.sessionID and TLSconn.cert leaked
- clarify the documentation in pushtls(2) and pushssl(2)

--- a/sys/src/cmd/ip/ftpfs/proto.c
+++ b/sys/src/cmd/ip/ftpfs/proto.c
@@ -55,6 +55,18 @@
 static int	getpassword(char*, char*);
 static int	nw_mode(char dirlet, char *s);
 
+static void
+starttls(int *fd)
+{
+	TLSconn c;	/* must be zeroed: tlsClient reads c.sessionID and c.cert */
+
+	memset(&c, 0, sizeof c);
+	if((*fd = tlsClient(*fd, &c)) < 0)
+		fatal("starting tls: %r");
+	free(c.sessionID);
+	free(c.cert);
+}
+
 /*
  *  connect to remote server, default network is "tcp/ip"
  */
@@ -63,7 +75,6 @@
 {
 	char *p;
 	char dir[Maxpath];
-	TLSconn conn;
 
 	Binit(&stdin, 0, OREAD);	/* init for later use */
 
@@ -93,11 +104,8 @@
 		if(getreply(&ctlin, msg, sizeof(msg), 1) != Success)
 			fatal("bad auth tls");
 
-		ctlfd = tlsClient(ctlfd, &conn);
-		if(ctlfd < 0)
-			fatal("starting tls: %r");
-		free(conn.cert);
-
+		starttls(&ctlfd);
+	
 		Binit(&ctlin, ctlfd, OREAD);
 
 		sendrequest("PBSZ", "0");
@@ -1227,7 +1235,6 @@
 	int cfd, dfd, rv;
 	char newdir[Maxpath];
 	char datafile[Maxpath + 6];
-	TLSconn conn;
 
 	if(port() < 0)
 		return TempFail;
@@ -1253,13 +1260,8 @@
 	if(dfd < 0)
 		fatal("opening data connection");
 
-	if(usetls){
-		memset(&conn, 0, sizeof(conn));
-		dfd = tlsClient(dfd, &conn);
-		if(dfd < 0)
-			fatal("starting tls: %r");
-		free(conn.cert);
-	}
+	if(usetls)
+		starttls(&dfd);
 
 	Binit(&dbuf, dfd, mode);
 	*bpp = &dbuf;
@@ -1277,7 +1279,6 @@
 	char *f[6];
 	char *p;
 	int x, fd;
-	TLSconn conn;
 
 	if(nopassive)
 		return Impossible;
@@ -1327,13 +1328,9 @@
 		return x;
 	}
 
-	if(usetls){
-		memset(&conn, 0, sizeof(conn));
-		fd = tlsClient(fd, &conn);
-		if(fd < 0)
-			fatal("starting tls: %r");
-		free(conn.cert);
-	}
+	if(usetls)
+		starttls(&fd);
+
 	Binit(&dbuf, fd, mode);
 
 	*bpp = &dbuf;
--- a/sys/src/cmd/ip/httpd/httpd.c
+++ b/sys/src/cmd/ip/httpd/httpd.c
@@ -172,7 +172,6 @@
 	NetConnInfo *nci;
 	char ndir[NETPATHLEN], dir[NETPATHLEN], *p, *scheme;
 	int ctl, nctl, data, t, ok, spotchk;
-	TLSconn conn;
 
 	spotchk = 0;
 	syslog(0, HTTPLOG, "httpd starting");
@@ -217,6 +216,8 @@
 			 */
 			data = accept(ctl, ndir);
 			if(data >= 0 && certificate != nil){
+				TLSconn conn;
+
 				memset(&conn, 0, sizeof(conn));
 				conn.cert = certificate;
 				conn.certlen = certlen;
@@ -223,6 +224,8 @@
 				if (certchain != nil)
 					conn.chain = certchain;
 				data = tlsServer(data, &conn);
+				free(conn.cert);
+				free(conn.sessionID);
 				scheme = "https";
 			}else
 				scheme = "http";
--- a/sys/src/cmd/ip/httpfile.c
+++ b/sys/src/cmd/ip/httpfile.c
@@ -186,12 +186,11 @@
 {
 	TLSconn conn;
 
+	memset(&conn, 0, sizeof(conn));
 	if((fd=tlsClient(fd, &conn)) < 0)
 		sysfatal("tlsclient: %r");
-
-	if(conn.cert != nil)
-		free(conn.cert);
-
+	free(conn.cert);
+	free(conn.sessionID);
 	return fd;
 }
 
--- a/sys/src/cmd/tlsclient.c
+++ b/sys/src/cmd/tlsclient.c
@@ -38,7 +38,7 @@
 void
 main(int argc, char **argv)
 {
-	int fd, netfd, debug;
+	int fd, debug;
 	uchar digest[20];
 	TLSconn *conn;
 	char *addr, *file, *filex, *ccert;
@@ -78,7 +78,7 @@
 	}
 
 	addr = argv[0];
-	if((netfd = dial(addr, 0, 0, 0)) < 0)
+	if((fd = dial(addr, 0, 0, 0)) < 0)
 		sysfatal("dial %s: %r", addr);
 
 	conn = (TLSconn*)mallocz(sizeof *conn, 1);
@@ -86,7 +86,7 @@
 		conn->cert = readcert(ccert, &conn->certlen);
 	if(debug)
 		conn->trace = reporter;
-	fd = tlsClient(netfd, conn);
+	fd = tlsClient(fd, conn);
 	if(fd < 0)
 		sysfatal("tlsclient: %r");
 	if(thumb){
@@ -98,8 +98,6 @@
 			sysfatal("server certificate %.*H not recognized", SHA1dlen, digest);
 		}
 	}
-	free(conn->cert);
-	close(netfd);
 
 	rfork(RFNOTEG);
 	switch(fork()){
--- a/sys/src/cmd/tlssrv.c
+++ b/sys/src/cmd/tlssrv.c
@@ -147,7 +147,7 @@
 	if(conn == nil)
 		sysfatal("out of memory");
 	conn->chain = readcertchain(cert);
-	if (conn->chain == nil)
+	if(conn->chain == nil)
 		sysfatal("can't read certificate");
 	conn->cert = conn->chain->pem;
 	conn->certlen = conn->chain->pemlen;
--- a/sys/src/cmd/upas/fs/pop3.c
+++ b/sys/src/cmd/upas/fs/pop3.c
@@ -129,6 +129,9 @@
 		err = "tls error";
 		goto out;
 	}
+	pop->fd = fd;
+	Binit(&pop->bin, pop->fd, OREAD);
+	Binit(&pop->bout, pop->fd, OWRITE);
 	if(conn.cert==nil || conn.certlen <= 0){
 		err = "server did not provide TLS certificate";
 		goto out;
@@ -140,17 +143,10 @@
 		err = "bad server certificate";
 		goto out;
 	}
-	close(pop->fd);
-	pop->fd = fd;
 	pop->encrypted = 1;
-	Binit(&pop->bin, pop->fd, OREAD);
-	Binit(&pop->bout, pop->fd, OWRITE);
-	fd = -1;
 out:
 	free(conn.sessionID);
 	free(conn.cert);
-	if(fd >= 0)
-		close(fd);
 	return err;
 }
 
--- a/sys/src/cmd/upas/pop3/pop3.c
+++ b/sys/src/cmd/upas/pop3/pop3.c
@@ -551,21 +551,23 @@
 static int
 stlscmd(char*)
 {
-	int fd;
 	TLSconn conn;
+	int fd;
 
 	if(didtls)
 		return senderr("tls already started");
 	if(!tlscert)
 		return senderr("don't have any tls credentials");
-	sendok("");
-	Bflush(&out);
-
 	memset(&conn, 0, sizeof conn);
-	conn.cert = tlscert;
 	conn.certlen = ntlscert;
+	conn.cert = malloc(ntlscert);
+	if(conn.cert == nil)
+		return senderr("out of memory");
+	memmove(conn.cert, tlscert, ntlscert);
 	if(debug)
 		conn.trace = trace;
+	sendok("");
+	Bflush(&out);
 	fd = tlsServer(0, &conn);
 	if(fd < 0)
 		sysfatal("tlsServer: %r");
@@ -572,6 +574,8 @@
 	dup(fd, 0);
 	dup(fd, 1);
 	close(fd);
+	free(conn.cert);
+	free(conn.sessionID);
 	Binit(&in, 0, OREAD);
 	Binit(&out, 1, OWRITE);
 	didtls = 1;
--- a/sys/src/cmd/upas/smtp/smtp.c
+++ b/sys/src/cmd/upas/smtp/smtp.c
@@ -324,26 +324,31 @@
 {
 	TLSconn *c;
 	Thumbprint *goodcerts;
-	char *h;
+	char *h, *err;
 	int fd;
 	uchar hash[SHA1dlen];

+	goodcerts = nil;	/* Out: may be reached before initThumbprints() runs */
+	err = Giveup;
 	c = mallocz(sizeof(*c), 1);
 	if (c == nil)
-		return Giveup;
+		return err;

 	fd = tlsClient(Bfildes(&bout), c);
 	if (fd < 0) {
 		syslog(0, "smtp", "tlsClient to %q: %r", ddomain);
-		free(c);
-		return Giveup;
+		goto Out;
 	}
+	Bterm(&bout);
+	Binit(&bout, fd, OWRITE);
+	fd = dup(fd, Bfildes(&bin));
+	Bterm(&bin);
+	Binit(&bin, fd, OREAD);
+
 	goodcerts = initThumbprints(smtpthumbs, smtpexclthumbs);
 	if (goodcerts == nil) {
 		syslog(0, "smtp", "bad thumbprints in %s", smtpthumbs);
-		close(fd);
-		free(c);
-		return Giveup;		/* how to recover? TLS is started */
+		goto Out;
 	}
 	/* compute sha1 hash of remote's certificate, see if we know it */
 	sha1(c->cert, c->certlen, hash, nil);
@@ -350,33 +355,23 @@
 	if (!okThumbprint(hash, goodcerts)) {
 		/* TODO? if not excluded, add hash to thumb list */
 		h = malloc(2*sizeof hash + 1);
-		if (h != nil) {
-			enc16(h, 2*sizeof hash + 1, hash, sizeof hash);
-			// fprint(2, "x509 sha1=%s", h);
-			syslog(0, "smtp",
-		"remote cert. has bad thumbprint: x509 sha1=%s server=%q",
-				h, ddomain);
-			free(h);
-		}
-		close(fd);
-		free(c);
-		return Giveup;		/* how to recover? TLS is started */
+		if (h == nil)
+			goto Out;
+		enc16(h, 2*sizeof hash + 1, hash, sizeof hash);
+		syslog(0, "smtp", "remote cert. has bad thumbprint: x509 sha1=%s server=%q",
+			h, ddomain);
+		free(h);
+		goto Out;
 	}
-	freeThumbprints(goodcerts);
-	free(c);
-	Bterm(&bin);
-	Bterm(&bout);
-
-	/*
-	 * set up bin & bout to use the TLS fd, i/o upon which generates
-	 * i/o on the original, underlying fd.
-	 */
-	Binit(&bin, fd, OREAD);
-	fd = dup(fd, -1);
-	Binit(&bout, fd, OWRITE);
-
 	syslog(0, "smtp", "started TLS to %q", ddomain);
-	return nil;
+	err = nil;
+Out:
+	if(goodcerts != nil)
+		freeThumbprints(goodcerts);
+	free(c->cert);
+	free(c->sessionID);
+	free(c);
+	return err;
 }
 
 /*
--- a/sys/src/cmd/upas/smtp/smtpd.c
+++ b/sys/src/cmd/upas/smtp/smtpd.c
@@ -1551,38 +1551,33 @@
 {
 	int certlen, fd;
 	uchar *cert;
-	TLSconn *conn;
+	TLSconn conn;
 
 	if (tlscert == nil) {
 		reply("500 5.5.1 illegal command or bad syntax\r\n");
 		return;
 	}
-	conn = mallocz(sizeof *conn, 1);
 	cert = readcert(tlscert, &certlen);
-	if (conn == nil || cert == nil) {
-		if (conn != nil)
-			free(conn);
+	if (cert == nil) {
 		reply("454 4.7.5 TLS not available\r\n");
 		return;
 	}
 	reply("220 2.0.0 Go ahead make my day\r\n");
-	conn->cert = cert;
-	conn->certlen = certlen;
-	fd = tlsServer(Bfildes(&bin), conn);
+	memset(&conn, 0, sizeof(conn));
+	conn.cert = cert;
+	conn.certlen = certlen;
+	fd = tlsServer(Bfildes(&bin), &conn);
 	if (fd < 0) {
-		free(cert);
-		free(conn);
 		syslog(0, "smtpd", "TLS start-up failed with %s", him);
-
-		/* force the client to hang up */
-		close(Bfildes(&bin));		/* probably fd 0 */
-		close(1);
 		exits("tls failed");
 	}
 	Bterm(&bin);
-	Binit(&bin, fd, OREAD);
-	if (dup(fd, 1) < 0)
+	if (dup(fd, 0) < 0 || dup(fd, 1) < 0)
 		fprint(2, "dup of %d failed: %r\n", fd);
+	close(fd);
+	Binit(&bin, 0, OREAD);
+	free(conn.cert);
+	free(conn.sessionID);
 	passwordinclear = 1;
 	syslog(0, "smtpd", "started TLS with %s", him);
 }
--- a/sys/src/cmd/vnc/vncs.c
+++ b/sys/src/cmd/vnc/vncs.c
@@ -152,7 +152,7 @@
 		exits(nil);
 	}
 
-	if(altnet && !cert)
+	if(altnet && cert == nil)
 		sysfatal("announcing on alternate network requires TLS (-c)");
 
 	if(argc == 0)
@@ -524,7 +524,6 @@
 {
 	char buf[32];
 	int fd;
-	TLSconn conn;
 
 	/* caller returns to listen */
 	switch(rfork(RFPROC|RFMEM|RFNAMEG)){
@@ -546,6 +545,8 @@
 	}
 
 	if(cert != nil){
+		TLSconn conn;
+
 		memset(&conn, 0, sizeof conn);
 		conn.cert = readcert(cert, &conn.certlen);
 		if(conn.cert == nil){
@@ -556,11 +557,9 @@
 		if(fd < 0){
 			fprint(2, "%V: tlsServer: %r; hanging up\n", v);
 			free(conn.cert);
-			if(conn.sessionID)
-				free(conn.sessionID);
+			free(conn.sessionID);
 			exits(nil);
 		}
-		close(v->datafd);
 		v->datafd = fd;
 		free(conn.cert);
 		free(conn.sessionID);
--- a/sys/src/cmd/vnc/vncv.c
+++ b/sys/src/cmd/vnc/vncv.c
@@ -84,7 +84,6 @@
 	int p, dfd, cfd, shared;
 	char *keypattern, *addr, *label;
 	Point d;
-	TLSconn conn;
 
 	keypattern = nil;
 	shared = 0;
@@ -123,10 +122,14 @@
 	if(dfd < 0)
 		sysfatal("cannot dial %s: %r", addr);
 	if(tls){
-		dfd = tlsClient(dfd, &conn);
-		if(dfd < 0)
+		TLSconn conn;
+
+		memset(&conn, 0, sizeof(conn));
+		if((dfd = tlsClient(dfd, &conn)) < 0)
 			sysfatal("tlsClient: %r");
 		/* XXX check thumbprint */
+		free(conn.cert);
+		free(conn.sessionID);
 	}
 	vnc = vncinit(dfd, cfd, nil);
 
--- a/sys/src/cmd/webfs/http.c
+++ b/sys/src/cmd/webfs/http.c
@@ -65,7 +65,7 @@
 {
 	char addr[128];
 	Hconn *h, *p;
-	int fd, ctl, ofd;
+	int fd, ofd, ctl;
 
 	snprint(addr, sizeof(addr), "tcp!%s!%s", u->host, u->port ? u->port : u->scheme);
 
@@ -90,18 +90,16 @@
 		return nil;
 	if(strcmp(u->scheme, "https") == 0){
 		char err[ERRMAX];
-		TLSconn *tc;
+		TLSconn conn;
 
-		tc = emalloc(sizeof(*tc));
 		strcpy(err, "tls error");
-		if((fd = tlsClient(ofd = fd, tc)) < 0)
+		memset(&conn, 0, sizeof(conn));
+		if((fd = tlsClient(ofd = fd, &conn)) < 0)
 			errstr(err, sizeof(err));
-		close(ofd);
-		/* BUG: should validate but how? */
-		free(tc->cert);
-		free(tc->sessionID);
-		free(tc);
+		free(conn.cert);
+		free(conn.sessionID);
 		if(fd < 0){
+			close(ofd);
 			close(ctl);
 			if(debug) fprint(2, "tlsClient: %s\n", err);
 			errstr(err, sizeof(err));
--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -335,8 +335,8 @@
 		return -1;
 	}
 	buf[n] = 0;
-	sprint(conn->dir, "#a/tls/%s", buf);
-	sprint(dname, "#a/tls/%s/hand", buf);
+	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
+	snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
 	hand = open(dname, ORDWR);
 	if(hand < 0){
 		close(ctl);
@@ -344,27 +344,32 @@
 	}
 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
-	sprint(dname, "#a/tls/%s/data", buf);
+	snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
 	data = open(dname, ORDWR);
-	close(fd);
 	close(hand);
 	close(ctl);
-	if(data < 0)
+	if(data < 0 || tls == nil){
+		if(tls != nil) tlsConnectionFree(tls);
+		if(data >= 0) close(data);	/* handshake failed: don't leak the data fd */
 		return -1;
-	if(tls == nil){
-		close(data);
-		return -1;
 	}
-	if(conn->cert)
-		free(conn->cert);
+	free(conn->cert);
 	conn->cert = 0;  // client certificates are not yet implemented
 	conn->certlen = 0;
 	conn->sessionIDlen = tls->sid->len;
 	conn->sessionID = emalloc(conn->sessionIDlen);
 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
-	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
-		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
+	if(conn->sessionKey != nil
+	&& conn->sessionType != nil
+	&& strcmp(conn->sessionType, "ttls") == 0)
+		tls->sec->prf(
+			conn->sessionKey, conn->sessionKeylen,
+			tls->sec->sec, MasterSecretSize,
+			conn->sessionConst, 
+			tls->sec->crandom, RandomSize,
+			tls->sec->srandom, RandomSize);
 	tlsConnectionFree(tls);
+	close(fd);
 	return data;
 }
 
@@ -378,7 +383,7 @@
 	int n, data, ctl, hand;
 	TlsConnection *tls;
 
-	if(!conn)
+	if(conn == nil)
 		return -1;
 	ctl = open("#a/tls/clone", ORDWR);
 	if(ctl < 0)
@@ -389,14 +394,14 @@
 		return -1;
 	}
 	buf[n] = 0;
-	sprint(conn->dir, "#a/tls/%s", buf);
-	sprint(dname, "#a/tls/%s/hand", buf);
+	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
+	snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
 	hand = open(dname, ORDWR);
 	if(hand < 0){
 		close(ctl);
 		return -1;
 	}
-	sprint(dname, "#a/tls/%s/data", buf);
+	snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
 	data = open(dname, ORDWR);
 	if(data < 0){
 		close(hand);
@@ -407,7 +412,6 @@
 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->trace);
 	close(hand);
 	close(ctl);
-	close(fd);
 	if(tls == nil){
 		close(data);
 		return -1;
@@ -418,9 +422,17 @@
 	conn->sessionIDlen = tls->sid->len;
 	conn->sessionID = emalloc(conn->sessionIDlen);
 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
-	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
-		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
+	if(conn->sessionKey != nil
+	&& conn->sessionType != nil
+	&& strcmp(conn->sessionType, "ttls") == 0)
+		tls->sec->prf(
+			conn->sessionKey, conn->sessionKeylen,
+			tls->sec->sec, MasterSecretSize,
+			conn->sessionConst, 
+			tls->sec->crandom, RandomSize,
+			tls->sec->srandom, RandomSize);
 	tlsConnectionFree(tls);
+	close(fd);
 	return data;
 }