shithub: riscv

Download patch

ref: 39f18c9d88f52a22373790dec5721fa3521d3f00
parent: 4a6ab355c1af789f7ddb4edbf4d82d17efd9d2bf
author: cinap_lenrek <[email protected]>
date: Fri Dec 25 12:05:05 EST 2015

libsec: implement TLS-PSK for tlsClient()/tlsServer()

--- a/sys/include/libsec.h
+++ b/sys/include/libsec.h
@@ -412,8 +412,10 @@
 	char	dir[40];	/* connection directory */
 	uchar	*cert;	/* certificate (local on input, remote on output) */
 	uchar	*sessionID;
+	uchar	*psk;
 	int	certlen;
 	int	sessionIDlen;
+	int	psklen;
 	int	(*trace)(char*fmt, ...);
 	PEMChain*chain;	/* optional extra certificate evidence for servers to present */
 	char	*sessionType;
@@ -421,6 +423,7 @@
 	int	sessionKeylen;
 	char	*sessionConst;
 	char	*serverName;
+	char	*pskID;
 } TLSconn;
 
 /* tlshand.c */
--- a/sys/man/2/pushtls
+++ b/sys/man/2/pushtls
@@ -100,7 +100,8 @@
 	char	dir[40];		/* OUT    connection directory */
 	uchar *cert;		/* IN/OUT certificate */
 	uchar *sessionID;	/* IN/OUT session ID */
-	int	certlen, sessionIDlen;
+	uchar *psk;		/* opt IN  pre-shared key */
+	int	certlen, sessionIDlen, psklen;
 	int	(*trace)(char*fmt, ...);
 	PEMChain *chain;
 	char	*sessionType;	/* opt IN  session type */
@@ -108,6 +109,7 @@
 	int	sessionKeylen;	/* opt IN  session key length */
 	char	*sessionConst;	/* opt IN  session constant */
 	char	*serverName;	/* opt IN  server name */
+	char	*pskID;		/* opt IN  pre-shared key ID */
 } TLSconn;
 .EE
 .PP
--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -135,9 +135,11 @@
 			Bytes **cas;
 		} certificateRequest;
 		struct {
+			Bytes *pskid;
 			Bytes *key;
 		} clientKeyExchange;
 		struct {
+			Bytes *pskid;
 			Bytes *dh_p;
 			Bytes *dh_g;
 			Bytes *dh_Ys;
@@ -159,6 +161,8 @@
 	int ok;	// <0 killed; == 0 in progress; >0 reusable
 	RSApub *rsapub;
 	AuthRpc *rpc;	// factotum for rsa private key
+	uchar *psk;	// pre-shared key
+	int psklen;
 	uchar sec[MasterSecretSize];	// master secret
 	uchar crandom[RandomSize];	// client random
 	uchar srandom[RandomSize];	// server random
@@ -223,6 +227,7 @@
 	EInternalError = 80,
 	EUserCanceled = 90,
 	ENoRenegotiation = 100,
+	EUnknownPSKidentity = 115,
 	EMax = 256
 };
 
@@ -274,6 +279,16 @@
 	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0XC013,
 	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0XC014,
 	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256	= 0xC027,
+
+	TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCA8,
+	TLS_DHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCAA,
+
+	GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305	= 0xCC13,
+	GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305	= 0xCC15,
+
+	TLS_PSK_WITH_CHACHA20_POLY1305		= 0xCCAB,
+	TLS_PSK_WITH_AES_128_CBC_SHA256		= 0x00AE,
+	TLS_PSK_WITH_AES_128_CBC_SHA		= 0x008C,
 };
 
 // compression methods
@@ -283,10 +298,12 @@
 };
 
 static Algs cipherAlgs[] = {
-	{"ccpoly96_aead", "clear", 2*(32+12), 0xCCA8},	// TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF)
-	{"ccpoly96_aead", "clear", 2*(32+12), 0xCCAA},	// TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF)
-	{"ccpoly64_aead", "clear", 2*32, 0xCC13},	// TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft)
-	{"ccpoly64_aead", "clear", 2*32, 0xCC15},	// TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft)
+	{"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
+	{"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305},
+
+	{"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305},
+	{"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305},
+
 	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256},
 	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
@@ -299,6 +316,11 @@
 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
 	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
 	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
+
+	// PSK cipher suites
+	{"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305},
+	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256},
+	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA},
 };
 
 static uchar compressors[] = {
@@ -327,8 +349,15 @@
 	0x0201,		/* SHA1 RSA */
 };
 
-static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
-static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
+static TlsConnection *tlsServer2(int ctl, int hand,
+	uchar *cert, int certlen,
+	char *pskid, uchar *psk, int psklen,
+	int (*trace)(char*fmt, ...), PEMChain *chain);
+static TlsConnection *tlsClient2(int ctl, int hand,
+	uchar *csid, int ncsid,
+	uchar *cert, int certlen,
+	char *pskid, uchar *psk, int psklen,
+	uchar *ext, int extlen, int (*trace)(char*fmt, ...));
 static void	msgClear(Msg *m);
 static char* msgPrint(char *buf, int n, Msg *m);
 static int	msgRecv(TlsConnection *c, Msg *m);
@@ -340,15 +369,17 @@
 static void tlsConnectionFree(TlsConnection *c);
 
 static int setAlgs(TlsConnection *c, int a);
-static int okCipher(Ints *cv);
+static int okCipher(Ints *cv, int ispsk);
 static int okCompression(Bytes *cv);
 static int initCiphers(void);
-static Ints* makeciphers(void);
+static Ints* makeciphers(int ispsk);
 
 static TlsSec*	tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
 static int	tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
+static int	tlsSecPSKs(TlsSec *sec, int vers);
 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
 static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
+static int	tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers);
 static Bytes*	tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
 static Bytes*	tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
 static int	tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
@@ -424,7 +455,10 @@
 		return -1;
 	}
 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
-	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
+	tls = tlsServer2(ctl, hand,
+		conn->cert, conn->certlen,
+		conn->pskID, conn->psk, conn->psklen,
+		conn->trace, conn->chain);
 	snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
 	data = open(dname, ORDWR);
 	close(hand);
@@ -435,7 +469,7 @@
 		return -1;
 	}
 	free(conn->cert);
-	conn->cert = 0;  // client certificates are not yet implemented
+	conn->cert = nil;  // client certificates are not yet implemented
 	conn->certlen = 0;
 	conn->sessionIDlen = tls->sid->len;
 	conn->sessionID = emalloc(conn->sessionIDlen);
@@ -561,7 +595,10 @@
 	}
 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
 	ext = tlsClientExtensions(conn, &n);
-	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
+	tls = tlsClient2(ctl, hand,
+		conn->sessionID, conn->sessionIDlen,
+		conn->cert, conn->certlen, 
+		conn->pskID, conn->psk, conn->psklen,
 		ext, n, conn->trace);
 	free(ext);
 	close(hand);
@@ -570,9 +607,14 @@
 		close(data);
 		return -1;
 	}
-	conn->certlen = tls->cert->len;
-	conn->cert = emalloc(conn->certlen);
-	memcpy(conn->cert, tls->cert->data, conn->certlen);
+	if(tls->cert != nil){
+		conn->certlen = tls->cert->len;
+		conn->cert = emalloc(conn->certlen);
+		memcpy(conn->cert, tls->cert->data, conn->certlen);
+	} else {
+		conn->certlen = 0;
+		conn->cert = nil;
+	}
 	conn->sessionIDlen = tls->sid->len;
 	conn->sessionID = emalloc(conn->sessionIDlen);
 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
@@ -603,7 +645,10 @@
 }
 
 static TlsConnection *
-tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
+tlsServer2(int ctl, int hand,
+	uchar *cert, int certlen,
+	char *pskid, uchar *psk, int psklen,
+	int (*trace)(char*fmt, ...), PEMChain *chp)
 {
 	TlsConnection *c;
 	Msg m;
@@ -641,7 +686,7 @@
 	}
 
 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
-	cipher = okCipher(m.u.clientHello.ciphers);
+	cipher = okCipher(m.u.clientHello.ciphers, psklen > 0);
 	if(cipher < 0) {
 		// reply with EInsufficientSecurity if we know that's the case
 		if(cipher == -2)
@@ -662,21 +707,27 @@
 
 	csid = m.u.clientHello.sid;
 	if(trace)
-		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
+		trace("  cipher %x, compressor %x, csidlen %d\n", cipher, compressor, csid->len);
 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
 	if(c->sec == nil){
 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
 		goto Err;
 	}
-	c->sec->rpc = factotum_rsa_open(cert, certlen);
-	if(c->sec->rpc == nil){
-		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
-		goto Err;
+	if(psklen > 0){
+		c->sec->psk = psk;
+		c->sec->psklen = psklen;
 	}
-	c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
-	if(c->sec->rsapub == nil){
-		tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
-		goto Err;
+	if(certlen > 0){
+		c->sec->rpc = factotum_rsa_open(cert, certlen);
+		if(c->sec->rpc == nil){
+			tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
+			goto Err;
+		}
+		c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
+		if(c->sec->rsapub == nil){
+			tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
+			goto Err;
+		}
 	}
 	msgClear(&m);
 
@@ -691,16 +742,18 @@
 		goto Err;
 	msgClear(&m);
 
-	m.tag = HCertificate;
-	numcerts = countchain(chp);
-	m.u.certificate.ncert = 1 + numcerts;
-	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
-	m.u.certificate.certs[0] = makebytes(cert, certlen);
-	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
-		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
-	if(!msgSend(c, &m, AQueue))
-		goto Err;
-	msgClear(&m);
+	if(certlen > 0 && psklen == 0){	// PSK suites omit the Certificate message (RFC 4279 §2)
+		m.tag = HCertificate;
+		numcerts = countchain(chp);
+		m.u.certificate.ncert = 1 + numcerts;
+		m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
+		m.u.certificate.certs[0] = makebytes(cert, certlen);
+		for (i = 0; i < numcerts && chp; i++, chp = chp->next)
+			m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
+		if(!msgSend(c, &m, AQueue))
+			goto Err;
+		msgClear(&m);
+	}
 
 	m.tag = HServerHelloDone;
 	if(!msgSend(c, &m, AFlush))
@@ -713,10 +766,29 @@
 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
 		goto Err;
 	}
-	if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
-		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+	if(pskid != nil){
+		if(m.u.clientKeyExchange.pskid == nil
+		|| m.u.clientKeyExchange.pskid->len != strlen(pskid)
+		|| memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){
+			tlsError(c, EUnknownPSKidentity, "unknown or missing pskid");
+			goto Err;
+		}
+	}
+	if(psklen > 0){	// okCipher() picked a PSK suite exactly when we have a psk
+		if(tlsSecPSKs(c->sec, c->version) < 0){
+			tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+			goto Err;
+		}
+	} else if(certlen > 0){
+		if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
+			tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+			goto Err;
+		}
+	} else {
+		tlsError(c, EInternalError, "no psk or certificate");
 		goto Err;
 	}
+
 	setSecrets(c->sec, kd, c->nsecret);
 	if(trace)
 		trace("tls secrets\n");
@@ -786,7 +858,8 @@
  	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
  	case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
  	case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
- 	case 0xCCAA: case 0xCC15:	// TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+	case TLS_DHE_RSA_WITH_CHACHA20_POLY1305:
+	case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305:
 		return 1;
 	}
 	return 0;
@@ -799,12 +872,25 @@
 	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
 	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
 	case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
-	case 0xCCA8: case 0xCC13:	// TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+	case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
+	case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305:
 		return 1;
 	}
 	return 0;
 }
 
+static int
+isPSK(int tlsid)
+{
+	switch(tlsid){
+	case TLS_PSK_WITH_CHACHA20_POLY1305:
+	case TLS_PSK_WITH_AES_128_CBC_SHA256:
+	case TLS_PSK_WITH_AES_128_CBC_SHA:
+		return 1;
+	}
+	return 0;
+}
+
 static Bytes*
 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
 	Bytes *p, Bytes *g, Bytes *Ys)
@@ -980,9 +1066,19 @@
 	RSApub *pk;
 	char *err;
 
-	if(sig == nil || sig->len <= 0)
+	if(par == nil || par->len <= 0)
+		return "no dh parameters";
+
+	if(sig == nil || sig->len <= 0){
+		if(c->sec->psklen > 0)
+			return nil;
+
 		return "no signature";
+	}
 
+	if(c->cert == nil)
+		return "no certificate";
+
 	pk = X509toRSApub(c->cert->data, c->cert->len, nil, 0);
 	if(pk == nil)
 		return "bad certificate";
@@ -1015,7 +1111,11 @@
 }
 
 static TlsConnection *
-tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
+tlsClient2(int ctl, int hand,
+	uchar *csid, int ncsid, 
+	uchar *cert, int certlen,
+	char *pskid, uchar *psk, int psklen,
+	uchar *ext, int extlen,
 	int (*trace)(char*fmt, ...))
 {
 	TlsConnection *c;
@@ -1036,10 +1136,17 @@
 	c->trace = trace;
 	c->isClient = 1;
 	c->clientVersion = c->version;
+	c->cert = nil;
 
 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
 	if(c->sec == nil)
 		goto Err;
+
+	if(psklen > 0){
+		c->sec->psk = psk;
+		c->sec->psklen = psklen;
+	}
+
 	/* client hello */
 	memset(&m, 0, sizeof(m));
 	m.tag = HClientHello;
@@ -1046,7 +1153,7 @@
 	m.u.clientHello.version = c->clientVersion;
 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
 	m.u.clientHello.sid = makebytes(csid, ncsid);
-	m.u.clientHello.ciphers = makeciphers();
+	m.u.clientHello.ciphers = makeciphers(psklen > 0);
 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
 	m.u.clientHello.extensions = makebytes(ext, extlen);
 	if(!msgSend(c, &m, AFlush))
@@ -1071,7 +1178,7 @@
 		goto Err;
 	}
 	cipher = m.u.serverHello.cipher;
-	if(!setAlgs(c, cipher)) {
+	if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) {
 		tlsError(c, EIllegalParameter, "invalid cipher suite");
 		goto Err;
 	}
@@ -1081,48 +1188,47 @@
 	}
 	msgClear(&m);
 
-	/* certificate */
-	if(!msgRecv(c, &m) || m.tag != HCertificate) {
-		tlsError(c, EUnexpectedMessage, "expected a certificate");
-		goto Err;
-	}
-	if(m.u.certificate.ncert < 1) {
-		tlsError(c, EIllegalParameter, "runt certificate");
-		goto Err;
-	}
-	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
-	msgClear(&m);
-
-	/* server key exchange */
 	dhx = isDHE(cipher) || isECDHE(cipher);
 	if(!msgRecv(c, &m))
 		goto Err;
+	if(m.tag == HCertificate){
+		if(m.u.certificate.ncert < 1) {
+			tlsError(c, EIllegalParameter, "runt certificate");
+			goto Err;
+		}
+		c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
+		msgClear(&m);
+		if(!msgRecv(c, &m))
+			goto Err;
+	} else if(psklen == 0) {
+		tlsError(c, EUnexpectedMessage, "expected a certificate");
+		goto Err;
+	}
 	if(m.tag == HServerKeyExchange) {
-		char *err;
-
-		if(!dhx){
+		if(dhx){
+			char *err = verifyDHparams(c,
+				m.u.serverKeyExchange.dh_parameters,
+				m.u.serverKeyExchange.dh_signature,
+				m.u.serverKeyExchange.sigalg);
+			if(err != nil){
+				tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
+				goto Err;
+			}
+			if(isECDHE(cipher))
+				epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
+					m.u.serverKeyExchange.curve,
+					m.u.serverKeyExchange.dh_Ys);
+			else
+				epm = tlsSecDHEc(c->sec, c->srandom, c->version,
+					m.u.serverKeyExchange.dh_p, 
+					m.u.serverKeyExchange.dh_g,
+					m.u.serverKeyExchange.dh_Ys);
+			if(epm == nil)
+				goto Badcert;
+		} else if(psklen == 0){
 			tlsError(c, EUnexpectedMessage, "got an server key exchange");
 			goto Err;
 		}
-		err = verifyDHparams(c,
-			m.u.serverKeyExchange.dh_parameters,
-			m.u.serverKeyExchange.dh_signature,
-			m.u.serverKeyExchange.sigalg);
-		if(err != nil){
-			tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
-			goto Err;
-		}
-		if(isECDHE(cipher))
-			epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
-				m.u.serverKeyExchange.curve,
-				m.u.serverKeyExchange.dh_Ys);
-		else
-			epm = tlsSecDHEc(c->sec, c->srandom, c->version,
-				m.u.serverKeyExchange.dh_p, 
-				m.u.serverKeyExchange.dh_g,
-				m.u.serverKeyExchange.dh_Ys);
-		if(epm == nil)
-			goto Badcert;
 		msgClear(&m);
 		if(!msgRecv(c, &m))
 			goto Err;
@@ -1146,14 +1252,22 @@
 	}
 	msgClear(&m);
 
-	if(!dhx)
-		epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
-			c->cert->data, c->cert->len, c->version);
-
-	if(epm == nil){
-	Badcert:
-		tlsError(c, EBadCertificate, "bad certificate: %r");
-		goto Err;
+	if(!dhx){
+		if(psklen > 0){	// suite is PSK here; checked against isPSK() after ServerHello
+			if(tlsSecPSKc(c->sec, c->srandom, c->version) < 0)
+				goto Badcert;
+		} else if(c->cert != nil){
+			epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
+				c->cert->data, c->cert->len, c->version);
+			if(epm == nil){
+			Badcert:
+				tlsError(c, EBadCertificate, "bad certificate: %r");
+				goto Err;
+			}
+		} else {
+			tlsError(c, EInternalError, "no psk or certificate");
+			goto Err;
+		}
 	}
 
 	setSecrets(c->sec, kd, c->nsecret);
@@ -1182,12 +1296,13 @@
 
 	/* client key exchange */
 	m.tag = HClientKeyExchange;
+	if(psklen > 0){
+		if(pskid == nil)
+			pskid = "";
+		m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid));
+	}
 	m.u.clientKeyExchange.key = epm;
 	epm = nil;
-	if(m.u.clientKeyExchange.key == nil) {
-		tlsError(c, EHandshakeFailure, "can't set secret: %r");
-		goto Err;
-	}
 	 
 	if(!msgSend(c, &m, AFlush))
 		goto Err;
@@ -1423,8 +1538,17 @@
 		p += 2;
 		memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
 		p += m->u.certificateVerify.signature->len;
-		break;	
+		break;
 	case HClientKeyExchange:
+		if(m->u.clientKeyExchange.pskid != nil){
+			n = m->u.clientKeyExchange.pskid->len;
+			put16(p, n);
+			p += 2;
+			memmove(p, m->u.clientKeyExchange.pskid->data, n);
+			p += n;
+		}
+		if(m->u.clientKeyExchange.key == nil)
+			break;
 		n = m->u.clientKeyExchange.key->len;
 		if(c->version != SSL3Version){
 			if(isECDHE(c->cipher))
@@ -1737,6 +1861,18 @@
 	case HServerHelloDone:
 		break;
 	case HServerKeyExchange:
+		if(isPSK(c->cipher)){
+			if(n < 2)
+				goto Short;
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn > n)
+				goto Short;
+			m->u.serverKeyExchange.pskid = makebytes(p, nn);
+			p += nn, n -= nn;
+			if(n == 0)
+				break;
+		}
 		if(n < 2)
 			goto Short;
 		s = p;
@@ -1805,6 +1941,18 @@
 		 * this message depends upon the encryption selected
 		 * assume rsa.
 		 */
+		if(isPSK(c->cipher)){
+			if(n < 2)
+				goto Short;
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn > n)
+				goto Short;
+			m->u.clientKeyExchange.pskid = makebytes(p, nn);
+			p += nn, n -= nn;
+			if(n == 0)
+				break;
+		}
 		if(c->version == SSL3Version)
 			nn = n;
 		else{
@@ -1883,6 +2031,7 @@
 	case HServerHelloDone:
 		break;
 	case HServerKeyExchange:
+		freebytes(m->u.serverKeyExchange.pskid);
 		freebytes(m->u.serverKeyExchange.dh_p);
 		freebytes(m->u.serverKeyExchange.dh_g);
 		freebytes(m->u.serverKeyExchange.dh_Ys);
@@ -1890,6 +2039,7 @@
 		freebytes(m->u.serverKeyExchange.dh_signature);
 		break;
 	case HClientKeyExchange:
+		freebytes(m->u.clientKeyExchange.pskid);
 		freebytes(m->u.clientKeyExchange.key);
 		break;
 	case HFinished:
@@ -1998,6 +2148,10 @@
 		break;
 	case HServerKeyExchange:
 		bs = seprint(bs, be, "HServerKeyExchange\n");
+		if(m->u.serverKeyExchange.pskid != nil)
+			bs = bytesPrint(bs, be, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n");
+		if(m->u.serverKeyExchange.dh_parameters == nil)
+			break;
 		if(m->u.serverKeyExchange.curve != 0){
 			bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
 		} else {
@@ -2012,7 +2166,10 @@
 		break;
 	case HClientKeyExchange:
 		bs = seprint(bs, be, "HClientKeyExchange\n");
-		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
+		if(m->u.clientKeyExchange.pskid != nil)
+			bs = bytesPrint(bs, be, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n");
+		if(m->u.clientKeyExchange.key != nil)
+			bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
 		break;
 	case HFinished:
 		bs = seprint(bs, be, "HFinished\n");
@@ -2137,7 +2294,7 @@
 }
 
 static int
-okCipher(Ints *cv)
+okCipher(Ints *cv, int ispsk)
 {
 	int weak, i, j, c;
 
@@ -2148,6 +2305,8 @@
 			weak = 0;
 		else
 			weak &= weakCipher[c];
+		if(isPSK(c) != ispsk)
+			continue;
 		if(isDHE(c) || isECDHE(c))
 			continue;	/* TODO: not implemented for server */
 		for(j = 0; j < nelem(cipherAlgs); j++)
@@ -2243,7 +2402,7 @@
 }
 
 static Ints*
-makeciphers(void)
+makeciphers(int ispsk)
 {
 	Ints *is;
 	int i, j;
@@ -2250,10 +2409,10 @@
 
 	is = newints(nciphers);
 	j = 0;
-	for(i = 0; i < nelem(cipherAlgs); i++){
-		if(cipherAlgs[i].ok)
+	for(i = 0; i < nelem(cipherAlgs); i++)
+		if(cipherAlgs[i].ok && isPSK(cipherAlgs[i].tlsid) == ispsk)
 			is->data[j++] = cipherAlgs[i].tlsid;
-	}
+	is->len = j;
 	return is;
 }
 
@@ -2489,6 +2648,17 @@
 	return -1;
 }
 
+static int
+tlsSecPSKs(TlsSec *sec, int vers)
+{
+	if(setVers(sec, vers) < 0){
+		sec->ok = -1;
+		return -1;
+	}
+	setMasterSecret(sec, newbytes(sec->psklen));
+	return 0;
+}
+
 static TlsSec*
 tlsSecInitc(int cvers, uchar *crandom)
 {
@@ -2500,6 +2670,18 @@
 	return sec;
 }
 
+static int
+tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers)
+{
+	memmove(sec->srandom, srandom, RandomSize);
+	if(setVers(sec, vers) < 0){
+		sec->ok = -1;
+		return -1;
+	}
+	setMasterSecret(sec, newbytes(sec->psklen));
+	return 0;
+}
+
 static Bytes*
 tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
 {
@@ -2608,6 +2790,22 @@
 static void
 setMasterSecret(TlsSec *sec, Bytes *pm)
 {
+	if(sec->psklen > 0){
+		Bytes *opm = pm;
+		uchar *p;
+
+		/* concatenate psk to pre-master secret */
+		pm = newbytes(4 + opm->len + sec->psklen);
+		p = pm->data;
+		put16(p, opm->len), p += 2;
+		memmove(p, opm->data, opm->len), p += opm->len;
+		put16(p, sec->psklen), p += 2;
+		memmove(p, sec->psk, sec->psklen);
+
+		memset(opm->data, 0, opm->len);
+		freebytes(opm);
+	}
+
 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
 			sec->crandom, RandomSize, sec->srandom, RandomSize);