ref: 39f18c9d88f52a22373790dec5721fa3521d3f00
parent: 4a6ab355c1af789f7ddb4edbf4d82d17efd9d2bf
author: cinap_lenrek <[email protected]>
date: Fri Dec 25 12:05:05 EST 2015
libsec: implement TLS-PSK for tlsClient()/tlsServer()
--- a/sys/include/libsec.h
+++ b/sys/include/libsec.h
@@ -412,8 +412,10 @@
char dir[40]; /* connection directory */
uchar *cert; /* certificate (local on input, remote on output) */
uchar *sessionID;
+ uchar *psk;
int certlen;
int sessionIDlen;
+ int psklen;
int (*trace)(char*fmt, ...);
PEMChain*chain; /* optional extra certificate evidence for servers to present */
char *sessionType;
@@ -421,6 +423,7 @@
int sessionKeylen;
char *sessionConst;
char *serverName;
+ char *pskID;
} TLSconn;
/* tlshand.c */
--- a/sys/man/2/pushtls
+++ b/sys/man/2/pushtls
@@ -100,7 +100,8 @@
char dir[40]; /* OUT connection directory */
uchar *cert; /* IN/OUT certificate */
uchar *sessionID; /* IN/OUT session ID */
- int certlen, sessionIDlen;
+ uchar *psk; /* opt IN pre-shared key */
+ int certlen, sessionIDlen, psklen;
int (*trace)(char*fmt, ...);
PEMChain *chain;
char *sessionType; /* opt IN session type */
@@ -108,6 +109,7 @@
int sessionKeylen; /* opt IN session key length */
char *sessionConst; /* opt IN session constant */
char *serverName; /* opt IN server name */
+ char *pskID; /* opt IN pre-shared key ID */
} TLSconn;
.EE
.PP
--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -135,9 +135,11 @@
Bytes **cas;
} certificateRequest;
struct {
+ Bytes *pskid;
Bytes *key;
} clientKeyExchange;
struct {
+ Bytes *pskid;
Bytes *dh_p;
Bytes *dh_g;
Bytes *dh_Ys;
@@ -159,6 +161,8 @@
int ok; // <0 killed; == 0 in progress; >0 reusable
RSApub *rsapub;
AuthRpc *rpc; // factotum for rsa private key
+ uchar *psk; // pre-shared key
+ int psklen;
uchar sec[MasterSecretSize]; // master secret
uchar crandom[RandomSize]; // client random
uchar srandom[RandomSize]; // server random
@@ -223,6 +227,7 @@
EInternalError = 80,
EUserCanceled = 90,
ENoRenegotiation = 100,
+ EUnknownPSKidentity = 115,
EMax = 256
};
@@ -274,6 +279,16 @@
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0XC013,
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0XC014,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027,
+
+ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCA8,
+ TLS_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCAA,
+
+ GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC13,
+ GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC15,
+
+ TLS_PSK_WITH_CHACHA20_POLY1305 = 0xCCAB,
+ TLS_PSK_WITH_AES_128_CBC_SHA256 = 0x00AE,
+ TLS_PSK_WITH_AES_128_CBC_SHA = 0x008C,
};
// compression methods
@@ -283,10 +298,12 @@
};
static Algs cipherAlgs[] = {
- {"ccpoly96_aead", "clear", 2*(32+12), 0xCCA8}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF)
- {"ccpoly96_aead", "clear", 2*(32+12), 0xCCAA}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF)
- {"ccpoly64_aead", "clear", 2*32, 0xCC13}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft)
- {"ccpoly64_aead", "clear", 2*32, 0xCC15}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft)
+ {"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
+ {"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305},
+
+ {"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305},
+ {"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305},
+
{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256},
{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
@@ -299,6 +316,11 @@
{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
{"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
{"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
+
+	// PSK cipher suites
+ {"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305},
+ {"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256},
+ {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA},
};
static uchar compressors[] = {
@@ -327,8 +349,15 @@
0x0201, /* SHA1 RSA */
};
-static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
-static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
+static TlsConnection *tlsServer2(int ctl, int hand,
+ uchar *cert, int certlen,
+ char *pskid, uchar *psk, int psklen,
+ int (*trace)(char*fmt, ...), PEMChain *chain);
+static TlsConnection *tlsClient2(int ctl, int hand,
+ uchar *csid, int ncsid,
+ uchar *cert, int certlen,
+ char *pskid, uchar *psk, int psklen,
+ uchar *ext, int extlen, int (*trace)(char*fmt, ...));
static void msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int msgRecv(TlsConnection *c, Msg *m);
@@ -340,15 +369,17 @@
static void tlsConnectionFree(TlsConnection *c);
static int setAlgs(TlsConnection *c, int a);
-static int okCipher(Ints *cv);
+static int okCipher(Ints *cv, int ispsk);
static int okCompression(Bytes *cv);
static int initCiphers(void);
-static Ints* makeciphers(void);
+static Ints* makeciphers(int ispsk);
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
+static int tlsSecPSKs(TlsSec *sec, int vers);
static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
+static int tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers);
static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
@@ -424,7 +455,10 @@
return -1;
}
fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
- tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
+ tls = tlsServer2(ctl, hand,
+ conn->cert, conn->certlen,
+ conn->pskID, conn->psk, conn->psklen,
+ conn->trace, conn->chain);
snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
data = open(dname, ORDWR);
close(hand);
@@ -435,7 +469,7 @@
return -1;
}
free(conn->cert);
- conn->cert = 0; // client certificates are not yet implemented
+ conn->cert = nil; // client certificates are not yet implemented
conn->certlen = 0;
conn->sessionIDlen = tls->sid->len;
conn->sessionID = emalloc(conn->sessionIDlen);
@@ -561,7 +595,10 @@
}
fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
ext = tlsClientExtensions(conn, &n);
- tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen,
+ tls = tlsClient2(ctl, hand,
+ conn->sessionID, conn->sessionIDlen,
+ conn->cert, conn->certlen,
+ conn->pskID, conn->psk, conn->psklen,
ext, n, conn->trace);
free(ext);
close(hand);
@@ -570,9 +607,14 @@
close(data);
return -1;
}
- conn->certlen = tls->cert->len;
- conn->cert = emalloc(conn->certlen);
- memcpy(conn->cert, tls->cert->data, conn->certlen);
+ if(tls->cert != nil){
+ conn->certlen = tls->cert->len;
+ conn->cert = emalloc(conn->certlen);
+ memcpy(conn->cert, tls->cert->data, conn->certlen);
+ } else {
+ conn->certlen = 0;
+ conn->cert = nil;
+ }
conn->sessionIDlen = tls->sid->len;
conn->sessionID = emalloc(conn->sessionIDlen);
memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
@@ -603,7 +645,10 @@
}
static TlsConnection *
-tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
+tlsServer2(int ctl, int hand,
+ uchar *cert, int certlen,
+ char *pskid, uchar *psk, int psklen,
+ int (*trace)(char*fmt, ...), PEMChain *chp)
{
TlsConnection *c;
Msg m;
@@ -641,7 +686,7 @@
}
memmove(c->crandom, m.u.clientHello.random, RandomSize);
- cipher = okCipher(m.u.clientHello.ciphers);
+ cipher = okCipher(m.u.clientHello.ciphers, psklen > 0);
if(cipher < 0) {
// reply with EInsufficientSecurity if we know that's the case
if(cipher == -2)
@@ -662,21 +707,27 @@
csid = m.u.clientHello.sid;
if(trace)
- trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
+ trace(" cipher %x, compressor %x, csidlen %d\n", cipher, compressor, csid->len);
c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
if(c->sec == nil){
tlsError(c, EHandshakeFailure, "can't initialize security: %r");
goto Err;
}
- c->sec->rpc = factotum_rsa_open(cert, certlen);
- if(c->sec->rpc == nil){
- tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
- goto Err;
+ if(psklen > 0){
+ c->sec->psk = psk;
+ c->sec->psklen = psklen;
}
- c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
- if(c->sec->rsapub == nil){
- tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
- goto Err;
+ if(certlen > 0){
+ c->sec->rpc = factotum_rsa_open(cert, certlen);
+ if(c->sec->rpc == nil){
+ tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
+ goto Err;
+ }
+ c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
+ if(c->sec->rsapub == nil){
+ tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
+ goto Err;
+ }
}
msgClear(&m);
@@ -691,16 +742,18 @@
goto Err;
msgClear(&m);
- m.tag = HCertificate;
- numcerts = countchain(chp);
- m.u.certificate.ncert = 1 + numcerts;
- m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
- m.u.certificate.certs[0] = makebytes(cert, certlen);
- for (i = 0; i < numcerts && chp; i++, chp = chp->next)
- m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
- if(!msgSend(c, &m, AQueue))
- goto Err;
- msgClear(&m);
+ if(certlen > 0){
+ m.tag = HCertificate;
+ numcerts = countchain(chp);
+ m.u.certificate.ncert = 1 + numcerts;
+ m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
+ m.u.certificate.certs[0] = makebytes(cert, certlen);
+ for (i = 0; i < numcerts && chp; i++, chp = chp->next)
+ m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
+ if(!msgSend(c, &m, AQueue))
+ goto Err;
+ msgClear(&m);
+ }
m.tag = HServerHelloDone;
if(!msgSend(c, &m, AFlush))
@@ -713,10 +766,29 @@
tlsError(c, EUnexpectedMessage, "expected a client key exchange");
goto Err;
}
- if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
- tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+ if(pskid != nil){
+ if(m.u.clientKeyExchange.pskid == nil
+ || m.u.clientKeyExchange.pskid->len != strlen(pskid)
+ || memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){
+ tlsError(c, EUnknownPSKidentity, "unknown or missing pskid");
+ goto Err;
+ }
+ }
+ if(certlen > 0){
+ if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
+ tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+ goto Err;
+ }
+ } else if(psklen > 0){
+ if(tlsSecPSKs(c->sec, c->version) < 0){
+ tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
+ goto Err;
+ }
+ } else {
+ tlsError(c, EInternalError, "no psk or certificate");
goto Err;
}
+
setSecrets(c->sec, kd, c->nsecret);
if(trace)
trace("tls secrets\n");
@@ -786,7 +858,8 @@
case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
- case 0xCCAA: case 0xCC15: // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+ case TLS_DHE_RSA_WITH_CHACHA20_POLY1305:
+ case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305:
return 1;
}
return 0;
@@ -799,12 +872,25 @@
case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
- case 0xCCA8: case 0xCC13: // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+ case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
+ case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305:
return 1;
}
return 0;
}
+/*
+ * isPSK reports whether tlsid is one of the plain pre-shared-key
+ * cipher suites listed at the end of cipherAlgs (1 if so, else 0).
+ */
+static int
+isPSK(int tlsid)
+{
+	return tlsid == TLS_PSK_WITH_CHACHA20_POLY1305
+		|| tlsid == TLS_PSK_WITH_AES_128_CBC_SHA256
+		|| tlsid == TLS_PSK_WITH_AES_128_CBC_SHA;
+}
+
static Bytes*
tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers,
Bytes *p, Bytes *g, Bytes *Ys)
@@ -980,9 +1066,19 @@
RSApub *pk;
char *err;
- if(sig == nil || sig->len <= 0)
+ if(par == nil || par->len <= 0)
+ return "no dh parameters";
+
+ if(sig == nil || sig->len <= 0){
+ if(c->sec->psklen > 0)
+ return nil;
+
return "no signature";
+ }
+ if(c->cert == nil)
+ return "no certificate";
+
pk = X509toRSApub(c->cert->data, c->cert->len, nil, 0);
if(pk == nil)
return "bad certificate";
@@ -1015,7 +1111,11 @@
}
static TlsConnection *
-tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
+tlsClient2(int ctl, int hand,
+ uchar *csid, int ncsid,
+ uchar *cert, int certlen,
+ char *pskid, uchar *psk, int psklen,
+ uchar *ext, int extlen,
int (*trace)(char*fmt, ...))
{
TlsConnection *c;
@@ -1036,10 +1136,17 @@
c->trace = trace;
c->isClient = 1;
c->clientVersion = c->version;
+ c->cert = nil;
c->sec = tlsSecInitc(c->clientVersion, c->crandom);
if(c->sec == nil)
goto Err;
+
+ if(psklen > 0){
+ c->sec->psk = psk;
+ c->sec->psklen = psklen;
+ }
+
/* client hello */
memset(&m, 0, sizeof(m));
m.tag = HClientHello;
@@ -1046,7 +1153,7 @@
m.u.clientHello.version = c->clientVersion;
memmove(m.u.clientHello.random, c->crandom, RandomSize);
m.u.clientHello.sid = makebytes(csid, ncsid);
- m.u.clientHello.ciphers = makeciphers();
+ m.u.clientHello.ciphers = makeciphers(psklen > 0);
m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
m.u.clientHello.extensions = makebytes(ext, extlen);
if(!msgSend(c, &m, AFlush))
@@ -1071,7 +1178,7 @@
goto Err;
}
cipher = m.u.serverHello.cipher;
- if(!setAlgs(c, cipher)) {
+ if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) {
tlsError(c, EIllegalParameter, "invalid cipher suite");
goto Err;
}
@@ -1081,48 +1188,47 @@
}
msgClear(&m);
- /* certificate */
- if(!msgRecv(c, &m) || m.tag != HCertificate) {
- tlsError(c, EUnexpectedMessage, "expected a certificate");
- goto Err;
- }
- if(m.u.certificate.ncert < 1) {
- tlsError(c, EIllegalParameter, "runt certificate");
- goto Err;
- }
- c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
- msgClear(&m);
-
- /* server key exchange */
dhx = isDHE(cipher) || isECDHE(cipher);
if(!msgRecv(c, &m))
goto Err;
+ if(m.tag == HCertificate){
+ if(m.u.certificate.ncert < 1) {
+ tlsError(c, EIllegalParameter, "runt certificate");
+ goto Err;
+ }
+ c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
+ msgClear(&m);
+ if(!msgRecv(c, &m))
+ goto Err;
+ } else if(psklen == 0) {
+ tlsError(c, EUnexpectedMessage, "expected a certificate");
+ goto Err;
+ }
if(m.tag == HServerKeyExchange) {
- char *err;
-
- if(!dhx){
+ if(dhx){
+ char *err = verifyDHparams(c,
+ m.u.serverKeyExchange.dh_parameters,
+ m.u.serverKeyExchange.dh_signature,
+ m.u.serverKeyExchange.sigalg);
+ if(err != nil){
+ tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
+ goto Err;
+ }
+ if(isECDHE(cipher))
+ epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
+ m.u.serverKeyExchange.curve,
+ m.u.serverKeyExchange.dh_Ys);
+ else
+ epm = tlsSecDHEc(c->sec, c->srandom, c->version,
+ m.u.serverKeyExchange.dh_p,
+ m.u.serverKeyExchange.dh_g,
+ m.u.serverKeyExchange.dh_Ys);
+ if(epm == nil)
+ goto Badcert;
+ } else if(psklen == 0){
tlsError(c, EUnexpectedMessage, "got an server key exchange");
goto Err;
}
- err = verifyDHparams(c,
- m.u.serverKeyExchange.dh_parameters,
- m.u.serverKeyExchange.dh_signature,
- m.u.serverKeyExchange.sigalg);
- if(err != nil){
- tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
- goto Err;
- }
- if(isECDHE(cipher))
- epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
- m.u.serverKeyExchange.curve,
- m.u.serverKeyExchange.dh_Ys);
- else
- epm = tlsSecDHEc(c->sec, c->srandom, c->version,
- m.u.serverKeyExchange.dh_p,
- m.u.serverKeyExchange.dh_g,
- m.u.serverKeyExchange.dh_Ys);
- if(epm == nil)
- goto Badcert;
msgClear(&m);
if(!msgRecv(c, &m))
goto Err;
@@ -1146,14 +1252,22 @@
}
msgClear(&m);
- if(!dhx)
- epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
- c->cert->data, c->cert->len, c->version);
-
- if(epm == nil){
- Badcert:
- tlsError(c, EBadCertificate, "bad certificate: %r");
- goto Err;
+ if(!dhx){
+ if(c->cert != nil){
+ epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
+ c->cert->data, c->cert->len, c->version);
+ if(epm == nil){
+ Badcert:
+ tlsError(c, EBadCertificate, "bad certificate: %r");
+ goto Err;
+ }
+ } else if(psklen > 0) {
+ if(tlsSecPSKc(c->sec, c->srandom, c->version) < 0)
+ goto Badcert;
+ } else {
+ tlsError(c, EInternalError, "no psk or certificate");
+ goto Err;
+ }
}
setSecrets(c->sec, kd, c->nsecret);
@@ -1182,12 +1296,13 @@
/* client key exchange */
m.tag = HClientKeyExchange;
+ if(psklen > 0){
+ if(pskid == nil)
+ pskid = "";
+ m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid));
+ }
m.u.clientKeyExchange.key = epm;
epm = nil;
- if(m.u.clientKeyExchange.key == nil) {
- tlsError(c, EHandshakeFailure, "can't set secret: %r");
- goto Err;
- }
if(!msgSend(c, &m, AFlush))
goto Err;
@@ -1423,8 +1538,17 @@
p += 2;
memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
p += m->u.certificateVerify.signature->len;
- break;
+ break;
case HClientKeyExchange:
+ if(m->u.clientKeyExchange.pskid != nil){
+ n = m->u.clientKeyExchange.pskid->len;
+ put16(p, n);
+ p += 2;
+ memmove(p, m->u.clientKeyExchange.pskid->data, n);
+ p += n;
+ }
+ if(m->u.clientKeyExchange.key == nil)
+ break;
n = m->u.clientKeyExchange.key->len;
if(c->version != SSL3Version){
if(isECDHE(c->cipher))
@@ -1737,6 +1861,18 @@
case HServerHelloDone:
break;
case HServerKeyExchange:
+ if(isPSK(c->cipher)){
+ if(n < 2)
+ goto Short;
+ nn = get16(p);
+ p += 2, n -= 2;
+ if(nn > n)
+ goto Short;
+ m->u.serverKeyExchange.pskid = makebytes(p, nn);
+ p += nn, n -= nn;
+ if(n == 0)
+ break;
+ }
if(n < 2)
goto Short;
s = p;
@@ -1805,6 +1941,18 @@
* this message depends upon the encryption selected
* assume rsa.
*/
+ if(isPSK(c->cipher)){
+ if(n < 2)
+ goto Short;
+ nn = get16(p);
+ p += 2, n -= 2;
+ if(nn > n)
+ goto Short;
+ m->u.clientKeyExchange.pskid = makebytes(p, nn);
+ p += nn, n -= nn;
+ if(n == 0)
+ break;
+ }
if(c->version == SSL3Version)
nn = n;
else{
@@ -1883,6 +2031,7 @@
case HServerHelloDone:
break;
case HServerKeyExchange:
+ freebytes(m->u.serverKeyExchange.pskid);
freebytes(m->u.serverKeyExchange.dh_p);
freebytes(m->u.serverKeyExchange.dh_g);
freebytes(m->u.serverKeyExchange.dh_Ys);
@@ -1890,6 +2039,7 @@
freebytes(m->u.serverKeyExchange.dh_signature);
break;
case HClientKeyExchange:
+ freebytes(m->u.clientKeyExchange.pskid);
freebytes(m->u.clientKeyExchange.key);
break;
case HFinished:
@@ -1998,6 +2148,10 @@
break;
case HServerKeyExchange:
bs = seprint(bs, be, "HServerKeyExchange\n");
+ if(m->u.serverKeyExchange.pskid != nil)
+ bs = bytesPrint(bs, be, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n");
+ if(m->u.serverKeyExchange.dh_parameters == nil)
+ break;
if(m->u.serverKeyExchange.curve != 0){
bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
} else {
@@ -2012,7 +2166,10 @@
break;
case HClientKeyExchange:
bs = seprint(bs, be, "HClientKeyExchange\n");
- bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
+ if(m->u.clientKeyExchange.pskid != nil)
+ bs = bytesPrint(bs, be, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n");
+ if(m->u.clientKeyExchange.key != nil)
+ bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
break;
case HFinished:
bs = seprint(bs, be, "HFinished\n");
@@ -2137,7 +2294,7 @@
}
static int
-okCipher(Ints *cv)
+okCipher(Ints *cv, int ispsk)
{
int weak, i, j, c;
@@ -2148,6 +2305,8 @@
weak = 0;
else
weak &= weakCipher[c];
+ if(isPSK(c) != ispsk)
+ continue;
if(isDHE(c) || isECDHE(c))
continue; /* TODO: not implemented for server */
for(j = 0; j < nelem(cipherAlgs); j++)
@@ -2243,7 +2402,7 @@
}
static Ints*
-makeciphers(void)
+makeciphers(int ispsk)	/* ispsk != 0: offer only PSK suites; 0: only non-PSK suites */
{
Ints *is;
int i, j;
@@ -2250,10 +2409,10 @@
is = newints(nciphers);
j = 0;
- for(i = 0; i < nelem(cipherAlgs); i++){
- if(cipherAlgs[i].ok)
+ for(i = 0; i < nelem(cipherAlgs); i++)
+ if(cipherAlgs[i].ok && isPSK(cipherAlgs[i].tlsid) == ispsk)	/* usable and PSK-ness matches the requested mode */
is->data[j++] = cipherAlgs[i].tlsid;
- }
+ is->len = j;	/* only j of the nciphers slots were filled; trim the advertised length */
return is;
}
@@ -2489,6 +2648,17 @@
return -1;
}
+static int
+tlsSecPSKs(TlsSec *sec, int vers)
+{
+	if(setVers(sec, vers) >= 0){
+		setMasterSecret(sec, newbytes(sec->psklen));	/* RFC 4279 other_secret: psklen zero octets (assumes newbytes zero-fills) */
+		return 0;
+	}
+	sec->ok = -1;	/* <0 marks the session killed, per TlsSec.ok */
+	return -1;
+}
+
static TlsSec*
tlsSecInitc(int cvers, uchar *crandom)
{
@@ -2500,6 +2670,18 @@
return sec;
}
+/*
+ * Client side of plain TLS-PSK: record the server random fed to the
+ * master-secret PRF, then derive keys as the server does (tlsSecPSKs).
+ * Returns 0 on success, -1 with sec->ok = -1 if setVers rejects vers.
+ */
+static int
+tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers)
+{
+	memmove(sec->srandom, srandom, RandomSize);
+	return tlsSecPSKs(sec, vers);
+}
+
static Bytes*
tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
{
@@ -2608,6 +2790,22 @@
static void
setMasterSecret(TlsSec *sec, Bytes *pm)
{
+ if(sec->psklen > 0){
+ Bytes *opm = pm;
+ uchar *p;
+
+ /* concatenate psk to pre-master secret */
+ pm = newbytes(4 + opm->len + sec->psklen);
+ p = pm->data;
+ put16(p, opm->len), p += 2;
+ memmove(p, opm->data, opm->len), p += opm->len;
+ put16(p, sec->psklen), p += 2;
+ memmove(p, sec->psk, sec->psklen);
+
+ memset(opm->data, 0, opm->len);
+ freebytes(opm);
+ }
+
(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
sec->crandom, RandomSize, sec->srandom, RandomSize);