shithub: riscv

Download patch

ref: de80075fc6bdc6dc785a67db2deaa59df020cfa6
parent: 02ffb19904f03cad21dd10a774705b9152d89010
author: cinap_lenrek <[email protected]>
date: Mon Apr 3 21:59:17 EDT 2017

tlshand: fix mpint to bytes conversion, reorganize send/recv buffer, check for overflow in msgSend()

when converting mpint to bytes, always pad it to the size of
the modulus (RSA, DHE, ECDHE). mptobytes() now takes a byte len
parameter which the caller usually calculates from the group
modulus using mpsignif(). the missing padding sometimes caused
"bad record mac" after the handshake.

use a single shared buffer: since msgSend()/msgRecv() never overlap,
we can fill the bottom of the buffer upward for sending and use the
top of the buffer for receiving, shifting received data down as
necessary. the space between sendp and recvp is free.

explicitly check for overflow in msgSend().

--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -17,7 +17,6 @@
 	TLSFinishedLen = 12,
 	SSL3FinishedLen = MD5dlen+SHA1dlen,
 	MaxKeyData = 160,	// amount of secret we may need
-	MaxChunk = 1<<15,
 	MAXdlen = SHA2_512dlen,
 	RandomSize = 32,
 	MasterSecretSize = 48,
@@ -100,13 +99,8 @@
 	HandshakeHash	handhash;
 	Finished	finished;
 
-	// input buffer for handshake messages
-	uchar recvbuf[MaxChunk];
-	uchar *rp, *ep;
-
-	// output buffer
-	uchar sendbuf[MaxChunk];
-	uchar *sendp;
+	uchar *sendp, *recvp, *recvw;
+	uchar buf[1<<16];
 } TlsConnection;
 
 typedef struct Msg{
@@ -444,7 +438,7 @@
 static int get16(uchar *p);
 static Bytes* newbytes(int len);
 static Bytes* makebytes(uchar* buf, int len);
-static Bytes* mptobytes(mpint* big);
+static Bytes* mptobytes(mpint* big, int len);
 static mpint* bytestomp(Bytes* bytes);
 static void freebytes(Bytes* b);
 static Ints* newints(int len);
@@ -696,6 +690,8 @@
 	c->hand = hand;
 	c->trace = trace;
 	c->version = ProtocolVersion;
+	c->sendp = c->buf;
+	c->recvp = c->recvw = &c->buf[sizeof(c->buf)];
 
 	memset(&m, 0, sizeof(m));
 	if(!msgRecv(c, &m)){
@@ -895,6 +891,7 @@
 	DHstate *dh = &sec->dh;
 	mpint *G, *P, *Y, *K;
 	Bytes *Yc;
+	int n;
 
 	if(p == nil || g == nil || Ys == nil)
 		return nil;
@@ -907,7 +904,8 @@
 
 	if(dh_new(dh, P, nil, G) == nil)
 		goto Out;
-	Yc = mptobytes(dh->y);
+	n = (mpsignif(P)+7)/8;
+	Yc = mptobytes(dh->y, n);
 	K = dh_finish(dh, Y);	/* zeros dh */
 	if(K == nil){
 		freebytes(Yc);
@@ -914,7 +912,7 @@
 		Yc = nil;
 		goto Out;
 	}
-	setMasterSecret(sec, mptobytes(K));
+	setMasterSecret(sec, mptobytes(K, n));
 
 Out:
 	mpfree(K);
@@ -934,6 +932,7 @@
 	ECpub *pub;
 	ECpoint K;
 	Bytes *Yc;
+	int n;
 
 	if(Ys == nil)
 		return nil;
@@ -959,8 +958,10 @@
 
 	ecgen(dom, Q);
 	ecmul(dom, pub, Q->d, &K);
-	setMasterSecret(sec, mptobytes(K.x));
-	Yc = newbytes(1 + 2*((mpsignif(dom->p)+7)/8));
+
+	n = (mpsignif(dom->p)+7)/8;
+	setMasterSecret(sec, mptobytes(K.x, n));
+	Yc = newbytes(1 + 2*n);
 	Yc->len = ecencodepub(dom, Q, Yc->data, Yc->len);
 
 	mpfree(K.x);
@@ -994,6 +995,8 @@
 	c->hand = hand;
 	c->trace = trace;
 	c->cert = nil;
+	c->sendp = c->buf;
+	c->recvp = c->recvw = &c->buf[sizeof(c->buf)];
 
 	c->version = ProtocolVersion;
 	tlsSecInitc(c->sec, c->version);
@@ -1257,14 +1260,13 @@
 static int
 msgSend(TlsConnection *c, Msg *m, int act)
 {
-	uchar *p; // sendp = start of new message;  p = write pointer
-	int nn, n, i;
+	uchar *p, *e; // sendp = start of new message;  p = write pointer; e = end pointer
+	int n, i;
 
-	if(c->sendp == nil)
-		c->sendp = c->sendbuf;
 	p = c->sendp;
+	e = c->recvp;
 	if(c->trace)
-		c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
+		c->trace("send %s", msgPrint((char*)p, e - p, m));
 
 	p[0] = m->tag;	// header - fill in size later
 	p += 4;
@@ -1274,119 +1276,111 @@
 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
 		goto Err;
 	case HClientHello:
-		// version
-		put16(p, m->u.clientHello.version);
-		p += 2;
-
-		// random
+		if(p+2+RandomSize > e)
+			goto Overflow;
+		put16(p, m->u.clientHello.version), p += 2;
 		memmove(p, m->u.clientHello.random, RandomSize);
 		p += RandomSize;
 
-		// sid
-		n = m->u.clientHello.sid->len;
-		p[0] = n;
-		memmove(p+1, m->u.clientHello.sid->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.clientHello.sid->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.clientHello.sid->data, n);
+		p += n;
 
-		n = m->u.clientHello.ciphers->len;
-		put16(p, n*2);
-		p += 2;
-		for(i=0; i<n; i++) {
-			put16(p, m->u.clientHello.ciphers->data[i]);
-			p += 2;
-		}
+		if(p+2+(n = m->u.clientHello.ciphers->len) > e)
+			goto Overflow;
+		put16(p, n*2), p += 2;
+		for(i=0; i<n; i++)
+			put16(p, m->u.clientHello.ciphers->data[i]), p += 2;
 
-		n = m->u.clientHello.compressors->len;
-		p[0] = n;
-		memmove(p+1, m->u.clientHello.compressors->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.clientHello.compressors->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.clientHello.compressors->data, n);
+		p += n;
 
-		if(m->u.clientHello.extensions == nil)
+		if(m->u.clientHello.extensions == nil
+		|| (n = m->u.clientHello.extensions->len) == 0)
 			break;
-		n = m->u.clientHello.extensions->len;
-		if(n == 0)
-			break;
-		put16(p, n);
-		memmove(p+2, m->u.clientHello.extensions->data, n);
-		p += n+2;
+		if(p+2+n > e)
+			goto Overflow;
+		put16(p, n), p += 2;
+		memmove(p, m->u.clientHello.extensions->data, n);
+		p += n;
 		break;
 	case HServerHello:
-		put16(p, m->u.serverHello.version);
-		p += 2;
-
-		// random
+		if(p+2+RandomSize > e)
+			goto Overflow;
+		put16(p, m->u.serverHello.version), p += 2;
 		memmove(p, m->u.serverHello.random, RandomSize);
 		p += RandomSize;
 
-		// sid
-		n = m->u.serverHello.sid->len;
-		p[0] = n;
-		memmove(p+1, m->u.serverHello.sid->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.serverHello.sid->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.serverHello.sid->data, n);
+		p += n;
 
-		put16(p, m->u.serverHello.cipher);
-		p += 2;
-		p[0] = m->u.serverHello.compressor;
-		p += 1;
+		if(p+2+1 > e)
+			goto Overflow;
+		put16(p, m->u.serverHello.cipher), p += 2;
+		*p++ = m->u.serverHello.compressor;
 
-		if(m->u.serverHello.extensions == nil)
-			break;
-		n = m->u.serverHello.extensions->len;
-		if(n == 0)
+		if(m->u.serverHello.extensions == nil
+		|| (n = m->u.serverHello.extensions->len) == 0)
 			break;
-		put16(p, n);
-		memmove(p+2, m->u.serverHello.extensions->data, n);
-		p += n+2;
+		if(p+2+n > e)
+			goto Overflow;
+		put16(p, n), p += 2;
+		memmove(p, m->u.serverHello.extensions->data, n);
+		p += n;
 		break;
 	case HServerHelloDone:
 		break;
 	case HCertificate:
-		nn = 0;
+		n = 0;
 		for(i = 0; i < m->u.certificate.ncert; i++)
-			nn += 3 + m->u.certificate.certs[i]->len;
-		if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
-			tlsError(c, EInternalError, "output buffer too small for certificate");
-			goto Err;
-		}
-		put24(p, nn);
-		p += 3;
+			n += 3 + m->u.certificate.certs[i]->len;
+		if(p+3+n > e)
+			goto Overflow;
+		put24(p, n), p += 3;
 		for(i = 0; i < m->u.certificate.ncert; i++){
-			put24(p, m->u.certificate.certs[i]->len);
-			p += 3;
-			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
-			p += m->u.certificate.certs[i]->len;
+			n = m->u.certificate.certs[i]->len;
+			put24(p, n), p += 3;
+			memmove(p, m->u.certificate.certs[i]->data, n);
+			p += n;
 		}
 		break;
 	case HCertificateVerify:
-		if(m->u.certificateVerify.sigalg != 0){
-			put16(p, m->u.certificateVerify.sigalg);
-			p += 2;
-		}
-		put16(p, m->u.certificateVerify.signature->len);
-		p += 2;
-		memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
-		p += m->u.certificateVerify.signature->len;
+		if(p+2+2+(n = m->u.certificateVerify.signature->len) > e)
+			goto Overflow;
+		if(m->u.certificateVerify.sigalg != 0)
+			put16(p, m->u.certificateVerify.sigalg), p += 2;
+		put16(p, n), p += 2;
+		memmove(p, m->u.certificateVerify.signature->data, n);
+		p += n;
 		break;
 	case HServerKeyExchange:
 		if(m->u.serverKeyExchange.pskid != nil){
-			n = m->u.serverKeyExchange.pskid->len;
-			put16(p, n);
-			p += 2;
+			if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e)
+				goto Overflow;
+			put16(p, n), p += 2;
 			memmove(p, m->u.serverKeyExchange.pskid->data, n);
 			p += n;
 		}
 		if(m->u.serverKeyExchange.dh_parameters == nil)
 			break;
-		n = m->u.serverKeyExchange.dh_parameters->len;
+		if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e)
+			goto Overflow;
 		memmove(p, m->u.serverKeyExchange.dh_parameters->data, n);
 		p += n;
 		if(m->u.serverKeyExchange.dh_signature == nil)
 			break;
-		if(c->version >= TLS12Version){
-			put16(p, m->u.serverKeyExchange.sigalg);
-			p += 2;
-		}
-		n = m->u.serverKeyExchange.dh_signature->len;
+		if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e)
+			goto Overflow;
+		if(c->version >= TLS12Version)
+			put16(p, m->u.serverKeyExchange.sigalg), p += 2;
 		put16(p, n), p += 2;
 		memmove(p, m->u.serverKeyExchange.dh_signature->data, n);
 		p += n;
@@ -1393,15 +1387,16 @@
 		break;
 	case HClientKeyExchange:
 		if(m->u.clientKeyExchange.pskid != nil){
-			n = m->u.clientKeyExchange.pskid->len;
-			put16(p, n);
-			p += 2;
+			if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e)
+				goto Overflow;
+			put16(p, n), p += 2;
 			memmove(p, m->u.clientKeyExchange.pskid->data, n);
 			p += n;
 		}
 		if(m->u.clientKeyExchange.key == nil)
 			break;
-		n = m->u.clientKeyExchange.key->len;
+		if(p+2+(n = m->u.clientKeyExchange.key->len) > e)
+			goto Overflow;
 		if(isECDHE(c->cipher))
 			*p++ = n;
 		else if(isDHE(c->cipher) || c->version != SSL3Version)
@@ -1410,6 +1405,8 @@
 		p += n;
 		break;
 	case HFinished:
+		if(p+m->u.finished.n > e)
+			goto Overflow;
 		memmove(p, m->u.finished.verify, m->u.finished.n);
 		p += m->u.finished.n;
 		break;
@@ -1417,7 +1414,6 @@
 
 	// go back and fill in size
 	n = p - c->sendp;
-	assert(n <= sizeof(c->sendbuf));
 	put24(c->sendp+1, n-4);
 
 	// remember hash of Handshake messages
@@ -1426,8 +1422,8 @@
 
 	c->sendp = p;
 	if(act == AFlush){
-		c->sendp = c->sendbuf;
-		if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
+		c->sendp = c->buf;
+		if(write(c->hand, c->buf, p - c->buf) < 0){
 			fprint(2, "write error: %r\n");
 			goto Err;
 		}
@@ -1434,6 +1430,8 @@
 	}
 	msgClear(m);
 	return 1;
+Overflow:
+	tlsError(c, EInternalError, "not enougth send buffer for message (%d)", m->tag);
 Err:
 	msgClear(m);
 	return 0;
@@ -1442,25 +1440,28 @@
 static uchar*
 tlsReadN(TlsConnection *c, int n)
 {
-	uchar *p;
-	int nn, nr;
+	uchar *p, *e;
 
-	nn = c->ep - c->rp;
-	if(nn < n){
-		if(c->rp != c->recvbuf){
-			memmove(c->recvbuf, c->rp, nn);
-			c->rp = c->recvbuf;
-			c->ep = &c->recvbuf[nn];
-		}
-		for(; nn < n; nn += nr) {
-			nr = read(c->hand, &c->rp[nn], n - nn);
-			if(nr <= 0)
-				return nil;
-			c->ep += nr;
-		}
+	p = c->recvp;
+	if(n <= c->recvw - p){
+		c->recvp += n;
+		return p;
 	}
-	p = c->rp;
-	c->rp += n;
+	e = &c->buf[sizeof(c->buf)];
+	c->recvp = e - n;
+	if(c->recvp < c->sendp || n > sizeof(c->buf)){
+		tlsError(c, EDecodeError, "handshake message too long %d", n);
+		return nil;
+	}
+	memmove(c->recvp, p, c->recvw - p);
+	c->recvw -= p - c->recvp;
+	p = c->recvp;
+	c->recvp += n;
+	while(c->recvw < c->recvp){
+		if((n = read(c->hand, c->recvw, e - c->recvw)) <= 0)
+			return nil;
+		c->recvw += n;
+	}
 	return p;
 }
 
@@ -1486,11 +1487,6 @@
 		}
 	}
 
-	if(n > sizeof(c->recvbuf)) {
-		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
-		return 0;
-	}
-
 	if(type == HSSL2ClientHello){
 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
 			This is sent by some clients that we must interoperate
@@ -1513,10 +1509,8 @@
 		p += 6;
 		n -= 6;
 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
-				|| nrandom < 16 || nn % 3)
+		|| nrandom < 16 || nn % 3 || n - nrandom < nn)
 			goto Err;
-		if(c->trace && (n - nrandom != nn))
-			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
 		nciph = 0;
 		for(i = 0; i < nn; i += 3)
@@ -1806,15 +1800,11 @@
 		break;
 	}
 
-	if(type != HClientHello && type != HServerHello && n != 0)
+	if(n != 0 && type != HClientHello && type != HServerHello)
 		goto Short;
 Ok:
-	if(c->trace){
-		char *buf;
-		buf = emalloc(8000);
-		c->trace("recv %s", msgPrint(buf, 8000, m));
-		free(buf);
-	}
+	if(c->trace)
+		c->trace("recv %s", msgPrint((char*)c->sendp, c->recvp - c->sendp, m));
 	return 1;
 Short:
 	tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
@@ -2624,8 +2614,9 @@
 	K.y = mpnew(0);
 
 	ecmul(dom, Y, Q->d, &K);
-	setMasterSecret(sec, mptobytes(K.x));
 
+	setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8));
+
 	mpfree(K.x);
 	mpfree(K.y);
 
@@ -2858,7 +2849,7 @@
 	y = factotum_rsa_decrypt(sec->rpc, bytestomp(data));
 	if(y == nil)
 		return nil;
-	data = mptobytes(y);
+	data = mptobytes(y, (mpsignif(y)+7)/8);
 	if((data->len = pkcs1unpadbuf(data->data, data->len, sec->rsapub->n, 2)) < 0){
 		freebytes(data);
 		return nil;
@@ -2884,10 +2875,11 @@
 		werrstr("bad digest algorithm");
 		return nil;
 	}
+
 	signedMP = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, digestlen, sec->rsapub->n, 1));
 	if(signedMP == nil)
 		return nil;
-	signature = mptobytes(signedMP);
+	signature = mptobytes(signedMP, (mpsignif(sec->rsapub->n)+7)/8);
 	mpfree(signedMP);
 	return signature;
 }
@@ -2999,14 +2991,12 @@
  * Convert mpint* to Bytes, putting high order byte first.
  */
 static Bytes*
-mptobytes(mpint* big)
+mptobytes(mpint *big, int len)
 {
 	Bytes* ans;
-	int n;
 
-	n = (mpsignif(big)+7)/8;
-	if(n == 0) n = 1;
-	ans = newbytes(n);
+	if(len == 0) len++;
+	ans = newbytes(len);
 	mptober(big, ans->data, ans->len);
 	return ans;
 }