ref: 4472cbe66edbb03865a9fb0f67970e0ce9deb0f2
dir: /sys/src/libsec/port/tlshand.c/
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <auth.h>
#include <mp.h>
#include <libsec.h>

// This file is organized into the following groups of functions:
//	client/server - main handshake protocol definition
//	message functions - formatting handshake messages
//	cipher choices - catalog of digest and encrypt algorithms
//	security functions - PKCS#1, sslHMAC, session keygen
//	general utility functions - malloc, serialization
// The handshake protocol sits on top of the TLS/SSL3 record layer
// protocol, which is implemented in kernel device #a.
// See also /lib/rfc/rfc2246.

enum {
	TLSFinishedLen = 12,
	SSL3FinishedLen = MD5dlen+SHA1dlen,
	MaxKeyData = 104,	// amount of secret we may need
	MaxChunk = 1<<14,
	RandomSize = 32,
	SidSize = 32,
	MasterSecretSize = 48,
	AQueue = 0,	// msgSend action: append to the send buffer
	AFlush = 1,	// msgSend action: append, then write the send buffer
};

typedef struct TlsSec TlsSec;

// counted byte string
typedef struct Bytes{
	int len;
	uchar data[1];	// [len]
} Bytes;

// counted int vector
typedef struct Ints{
	int len;
	int data[1];	// [len]
} Ints;

// one entry of the cipher suite catalog
typedef struct Algs{
	char *enc;
	char *digest;
	int nsecret;
	int tlsid;
	int ok;
} Algs;

// verify data of a Finished message; n bytes of verify are in use
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;

typedef struct TlsConnection{
	TlsSec *sec;	// security management goo
	int hand;	// record layer file descriptor: handshake
	int ctl;	// record layer file descriptor: control
	int erred;	// set when tlsError called
	int (*trace)(char *fmt, ...);	// for debugging
	int version;	// protocol we are speaking
	int verset;	// version has been set
	int ver2hi;	// server got a version 2 hello
	int isClient;	// is this the client or server?
	Bytes *sid;	// SessionID
	Bytes *cert;	// only last - no chain

	Lock statelk;
	int state;	// must be set using setstate

	// input buffer for handshake messages
	uchar buf[MaxChunk+2048];
	uchar *rp;
	uchar *ep;

	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVersion;	// version in ClientHello
	char *digest;	// name of digest algorithm to use
	char *enc;	// name of encryption algorithm to use
	int nsecret;	// amount of secret data to init keys

	// for finished messages
	MD5state hsmd5;	// handshake hash
	SHAstate hssha1;	// handshake hash
	Finished finished;
} TlsConnection;

// decoded handshake message; tag selects the live member of u
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar random[RandomSize];
			Bytes *sid;
			Ints *ciphers;
			Bytes *compressors;
		} clientHello;
		struct {
			int version;
			uchar random[RandomSize];
			Bytes *sid;
			int cipher;
			int compressor;
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *key;
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;

typedef struct TlsSec{
	char *server;	// name of remote; nil for server
	int ok;	// <0 killed; == 0 in progress; >0 reusable
	RSApub *rsapub;
	AuthRpc *rpc;	// factotum for rsa private key
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVers;	// version in ClientHello
	int vers;	// final version
	// byte generation and handshake checksum
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;
} TlsSec;

// protocol versions
enum {
	TLSVersion = 0x0301,
	SSL3Version = 0x0300,
	ProtocolVersion = 0x0301,	// maximum version we speak
	MinProtoVersion = 0x0300,	// limits on version we accept
	MaxProtoVersion = 0x03ff,
};

// handshake type
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,	/* local convention; see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};

// alerts
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};

// cipher suites
enum {
	TLS_NULL_WITH_NULL_NULL = 0x0000,
	TLS_RSA_WITH_NULL_MD5 = 0x0001,
	TLS_RSA_WITH_NULL_SHA = 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 = 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 = 0x0004,
	TLS_RSA_WITH_RC4_128_SHA = 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 = 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA = 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA = 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA = 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA = 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA = 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA = 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA = 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA = 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA = 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA = 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA = 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA = 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA = 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA = 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA = 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA = 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 = 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 = 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA = 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA = 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA = 0x001B,

	TLS_RSA_WITH_AES_128_CBC_SHA = 0x002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA = 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA = 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA = 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA = 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA = 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA = 0x003A,
	CipherMax
};

// compression methods
enum {
	CompressionNull = 0,
	CompressionMax
};

// catalog of supported suites: encryption name, digest name,
// amount of key material needed, suite id
static Algs cipherAlgs[] = {
	{"rc4_128", "md5", 2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1", 2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
};

// compression methods we offer and accept
static uchar compressors[] = {
	CompressionNull,
};

// client/server
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char *fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char *fmt, ...));

// message functions
static void msgClear(Msg *m);
static char *msgPrint(char *buf, int n, Msg *m);
static int msgRecv(TlsConnection *c, Msg *m);
static int msgSend(TlsConnection *c, Msg *m, int act);
static void tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck argpos	tlsError 3
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher choices
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints *makeciphers(void);

// security functions
static TlsSec *tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec *tlsSecInitc(int cvers, uchar *crandom);
static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void tlsSecOk(TlsSec *sec);
static void tlsSecKill(TlsSec *sec);
static void tlsSecClose(TlsSec *sec);
static void setMasterSecret(TlsSec *sec, Bytes *pm);
static void serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes *data, RSApub *key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

static AuthRpc *factotum_rsa_open(uchar *cert, int certlen);
static mpint *factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc *rpc);

// general utility functions
static void *emalloc(int);
static void *erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static u32int get32(uchar *p);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes *newbytes(int len);
static Bytes *makebytes(uchar *buf, int len);
static void freebytes(Bytes *b);
static Ints *newints(int len);
static Ints *makeints(int *buf, int len);
static void freeints(Ints *b);

//================= client/server ========================

// push TLS onto fd, returning new (application) file descriptor
// or -1 if error.
// Server side.  Clones a new #a/tls connection, writes its directory
// name to conn->dir, hands fd to the record layer and runs the server
// handshake (tlsServer2) using conn->cert and conn->chain.
// Once the handshake has been attempted, fd, hand and ctl are closed
// whether or not it succeeded.
// On success conn->cert is freed and cleared, conn->sessionID receives
// a freshly allocated copy of the session id and, when
// conn->sessionType is "ttls" and conn->sessionKey is set, the session
// key is derived from the master secret with the negotiated prf.
int
tlsServer(int fd, TLSconn *conn)
{
	char buf[8];
	char dname[64];
	int n, data, ctl, hand;
	TlsConnection *tls;

	if(conn == nil)
		return -1;
	ctl = open("#a/tls/clone", ORDWR);
	if(ctl < 0)
		return -1;
	// the clone file yields the number of the new connection
	n = read(ctl, buf, sizeof(buf)-1);
	if(n < 0){
		close(ctl);
		return -1;
	}
	buf[n] = 0;
	sprint(conn->dir, "#a/tls/%s", buf);
	sprint(dname, "#a/tls/%s/hand", buf);
	hand = open(dname, ORDWR);
	if(hand < 0){
		close(ctl);
		return -1;
	}
	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
	sprint(dname, "#a/tls/%s/data", buf);
	data = open(dname, ORDWR);
	close(fd);
	close(hand);
	close(ctl);
	if(data < 0){
		// the handshake may have succeeded; don't leak its state
		if(tls != nil)
			tlsConnectionFree(tls);
		return -1;
	}
	if(tls == nil){
		close(data);
		return -1;
	}
	free(conn->cert);
	conn->cert = 0;	// client certificates are not yet implemented
	conn->certlen = 0;
	conn->sessionIDlen = tls->sid->len;
	conn->sessionID = emalloc(conn->sessionIDlen);
	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
	tlsConnectionFree(tls);
	return data;
}

// push TLS onto fd, returning new (application) file descriptor
// or -1 if error.
int tlsClient(int fd, TLSconn *conn) { char buf[8]; char dname[64]; int n, data, ctl, hand; TlsConnection *tls; if(!conn) return -1; ctl = open("#a/tls/clone", ORDWR); if(ctl < 0) return -1; n = read(ctl, buf, sizeof(buf)-1); if(n < 0){ close(ctl); return -1; } buf[n] = 0; sprint(conn->dir, "#a/tls/%s", buf); sprint(dname, "#a/tls/%s/hand", buf); hand = open(dname, ORDWR); if(hand < 0){ close(ctl); return -1; } sprint(dname, "#a/tls/%s/data", buf); data = open(dname, ORDWR); if(data < 0) return -1; fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace); close(fd); close(hand); close(ctl); if(tls == nil){ close(data); return -1; } conn->certlen = tls->cert->len; conn->cert = emalloc(conn->certlen); memcpy(conn->cert, tls->cert->data, conn->certlen); conn->sessionIDlen = tls->sid->len; conn->sessionID = emalloc(conn->sessionIDlen); memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen); if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0) tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize); tlsConnectionFree(tls); return data; } static int countchain(PEMChain *p) { int i = 0; while (p) { i++; p = p->next; } return i; } static TlsConnection * tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp) { TlsConnection *c; Msg m; Bytes *csid; uchar sid[SidSize], kd[MaxKeyData]; char *secrets; int cipher, compressor, nsid, rv, numcerts, i; if(trace) trace("tlsServer2\n"); if(!initCiphers()) return nil; c = emalloc(sizeof(TlsConnection)); c->ctl = ctl; c->hand = hand; c->trace = trace; c->version = ProtocolVersion; memset(&m, 0, sizeof(m)); if(!msgRecv(c, &m)){ if(trace) trace("initial msgRecv failed\n"); goto Err; } if(m.tag != HClientHello) { tlsError(c, EUnexpectedMessage, "expected a client 
hello"); goto Err; } c->clientVersion = m.u.clientHello.version; if(trace) trace("ClientHello version %x\n", c->clientVersion); if(setVersion(c, m.u.clientHello.version) < 0) { tlsError(c, EIllegalParameter, "incompatible version"); goto Err; } memmove(c->crandom, m.u.clientHello.random, RandomSize); cipher = okCipher(m.u.clientHello.ciphers); if(cipher < 0) { // reply with EInsufficientSecurity if we know that's the case if(cipher == -2) tlsError(c, EInsufficientSecurity, "cipher suites too weak"); else tlsError(c, EHandshakeFailure, "no matching cipher suite"); goto Err; } if(!setAlgs(c, cipher)){ tlsError(c, EHandshakeFailure, "no matching cipher suite"); goto Err; } compressor = okCompression(m.u.clientHello.compressors); if(compressor < 0) { tlsError(c, EHandshakeFailure, "no matching compressor"); goto Err; } csid = m.u.clientHello.sid; if(trace) trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len); c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom); if(c->sec == nil){ tlsError(c, EHandshakeFailure, "can't initialize security: %r"); goto Err; } c->sec->rpc = factotum_rsa_open(cert, ncert); if(c->sec->rpc == nil){ tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); goto Err; } c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0); msgClear(&m); m.tag = HServerHello; m.u.serverHello.version = c->version; memmove(m.u.serverHello.random, c->srandom, RandomSize); m.u.serverHello.cipher = cipher; m.u.serverHello.compressor = compressor; c->sid = makebytes(sid, nsid); m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len); if(!msgSend(c, &m, AQueue)) goto Err; msgClear(&m); m.tag = HCertificate; numcerts = countchain(chp); m.u.certificate.ncert = 1 + numcerts; m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes)); m.u.certificate.certs[0] = makebytes(cert, ncert); for (i = 0; i < numcerts && chp; i++, chp = chp->next) m.u.certificate.certs[i+1] = makebytes(chp->pem, 
chp->pemlen); if(!msgSend(c, &m, AQueue)) goto Err; msgClear(&m); m.tag = HServerHelloDone; if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); if(!msgRecv(c, &m)) goto Err; if(m.tag != HClientKeyExchange) { tlsError(c, EUnexpectedMessage, "expected a client key exchange"); goto Err; } if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){ tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); goto Err; } if(trace) trace("tls secrets\n"); secrets = (char*)emalloc(2*c->nsecret); enc64(secrets, 2*c->nsecret, kd, c->nsecret); rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets); memset(secrets, 0, 2*c->nsecret); free(secrets); memset(kd, 0, c->nsecret); if(rv < 0){ tlsError(c, EHandshakeFailure, "can't set keys: %r"); goto Err; } msgClear(&m); /* no CertificateVerify; skip to Finished */ if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } if(!msgRecv(c, &m)) goto Err; if(m.tag != HFinished) { tlsError(c, EUnexpectedMessage, "expected a finished"); goto Err; } if(!finishedMatch(c, &m.u.finished)) { tlsError(c, EHandshakeFailure, "finished verification failed"); goto Err; } msgClear(&m); /* change cipher spec */ if(fprint(c->ctl, "changecipher") < 0){ tlsError(c, EInternalError, "can't enable cipher: %r"); goto Err; } if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } m.tag = HFinished; m.u.finished = c->finished; if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); if(trace) trace("tls finished\n"); if(fprint(c->ctl, "opened") < 0) goto Err; tlsSecOk(c->sec); return c; Err: msgClear(&m); tlsConnectionFree(c); return 0; } static TlsConnection * tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...)) { TlsConnection *c; Msg m; uchar kd[MaxKeyData], 
*epm; char *secrets; int creq, nepm, rv; if(!initCiphers()) return nil; epm = nil; c = emalloc(sizeof(TlsConnection)); c->version = ProtocolVersion; c->ctl = ctl; c->hand = hand; c->trace = trace; c->isClient = 1; c->clientVersion = c->version; c->sec = tlsSecInitc(c->clientVersion, c->crandom); if(c->sec == nil) goto Err; /* client hello */ memset(&m, 0, sizeof(m)); m.tag = HClientHello; m.u.clientHello.version = c->clientVersion; memmove(m.u.clientHello.random, c->crandom, RandomSize); m.u.clientHello.sid = makebytes(csid, ncsid); m.u.clientHello.ciphers = makeciphers(); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); /* server hello */ if(!msgRecv(c, &m)) goto Err; if(m.tag != HServerHello) { tlsError(c, EUnexpectedMessage, "expected a server hello"); goto Err; } if(setVersion(c, m.u.serverHello.version) < 0) { tlsError(c, EIllegalParameter, "incompatible version %r"); goto Err; } memmove(c->srandom, m.u.serverHello.random, RandomSize); c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len); if(c->sid->len != 0 && c->sid->len != SidSize) { tlsError(c, EIllegalParameter, "invalid server session identifier"); goto Err; } if(!setAlgs(c, m.u.serverHello.cipher)) { tlsError(c, EIllegalParameter, "invalid cipher suite"); goto Err; } if(m.u.serverHello.compressor != CompressionNull) { tlsError(c, EIllegalParameter, "invalid compression"); goto Err; } msgClear(&m); /* certificate */ if(!msgRecv(c, &m) || m.tag != HCertificate) { tlsError(c, EUnexpectedMessage, "expected a certificate"); goto Err; } if(m.u.certificate.ncert < 1) { tlsError(c, EIllegalParameter, "runt certificate"); goto Err; } c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len); msgClear(&m); /* server key exchange (optional) */ if(!msgRecv(c, &m)) goto Err; if(m.tag == HServerKeyExchange) { tlsError(c, EUnexpectedMessage, "got an server key exchange"); goto Err; // If 
implementing this later, watch out for rollback attack // described in Wagner Schneier 1996, section 4.4. } /* certificate request (optional) */ creq = 0; if(m.tag == HCertificateRequest) { creq = 1; msgClear(&m); if(!msgRecv(c, &m)) goto Err; } if(m.tag != HServerHelloDone) { tlsError(c, EUnexpectedMessage, "expected a server hello done"); goto Err; } msgClear(&m); if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom, c->cert->data, c->cert->len, c->version, &epm, &nepm, kd, c->nsecret) < 0){ tlsError(c, EBadCertificate, "invalid x509/rsa certificate"); goto Err; } secrets = (char*)emalloc(2*c->nsecret); enc64(secrets, 2*c->nsecret, kd, c->nsecret); rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets); memset(secrets, 0, 2*c->nsecret); free(secrets); memset(kd, 0, c->nsecret); if(rv < 0){ tlsError(c, EHandshakeFailure, "can't set keys: %r"); goto Err; } if(creq) { /* send a zero length certificate */ m.tag = HCertificate; if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); } /* client key exchange */ m.tag = HClientKeyExchange; m.u.clientKeyExchange.key = makebytes(epm, nepm); free(epm); epm = nil; if(m.u.clientKeyExchange.key == nil) { tlsError(c, EHandshakeFailure, "can't set secret: %r"); goto Err; } if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); /* change cipher spec */ if(fprint(c->ctl, "changecipher") < 0){ tlsError(c, EInternalError, "can't enable cipher: %r"); goto Err; } // Cipherchange must occur immediately before Finished to avoid // potential hole; see section 4.3 of Wagner Schneier 1996. 
if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished 1: %r"); goto Err; } m.tag = HFinished; m.u.finished = c->finished; if(!msgSend(c, &m, AFlush)) { fprint(2, "tlsClient nepm=%d\n", nepm); tlsError(c, EInternalError, "can't flush after client Finished: %r"); goto Err; } msgClear(&m); if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ fprint(2, "tlsClient nepm=%d\n", nepm); tlsError(c, EInternalError, "can't set finished 0: %r"); goto Err; } if(!msgRecv(c, &m)) { fprint(2, "tlsClient nepm=%d\n", nepm); tlsError(c, EInternalError, "can't read server Finished: %r"); goto Err; } if(m.tag != HFinished) { fprint(2, "tlsClient nepm=%d\n", nepm); tlsError(c, EUnexpectedMessage, "expected a Finished msg from server"); goto Err; } if(!finishedMatch(c, &m.u.finished)) { tlsError(c, EHandshakeFailure, "finished verification failed"); goto Err; } msgClear(&m); if(fprint(c->ctl, "opened") < 0){ if(trace) trace("unable to do final open: %r\n"); goto Err; } tlsSecOk(c->sec); return c; Err: free(epm); msgClear(&m); tlsConnectionFree(c); return 0; } //================= message functions ======================== static uchar sendbuf[9000], *sendp; static int msgSend(TlsConnection *c, Msg *m, int act) { uchar *p; // sendp = start of new message; p = write pointer int nn, n, i; if(sendp == nil) sendp = sendbuf; p = sendp; if(c->trace) c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m)); p[0] = m->tag; // header - fill in size later p += 4; switch(m->tag) { default: tlsError(c, EInternalError, "can't encode a %d", m->tag); goto Err; case HClientHello: // version put16(p, m->u.clientHello.version); p += 2; // random memmove(p, m->u.clientHello.random, RandomSize); p += RandomSize; // sid n = m->u.clientHello.sid->len; assert(n < 256); p[0] = n; memmove(p+1, m->u.clientHello.sid->data, n); p += n+1; n = m->u.clientHello.ciphers->len; 
assert(n > 0 && n < 200); put16(p, n*2); p += 2; for(i=0; i<n; i++) { put16(p, m->u.clientHello.ciphers->data[i]); p += 2; } n = m->u.clientHello.compressors->len; assert(n > 0); p[0] = n; memmove(p+1, m->u.clientHello.compressors->data, n); p += n+1; break; case HServerHello: put16(p, m->u.serverHello.version); p += 2; // random memmove(p, m->u.serverHello.random, RandomSize); p += RandomSize; // sid n = m->u.serverHello.sid->len; assert(n < 256); p[0] = n; memmove(p+1, m->u.serverHello.sid->data, n); p += n+1; put16(p, m->u.serverHello.cipher); p += 2; p[0] = m->u.serverHello.compressor; p += 1; break; case HServerHelloDone: break; case HCertificate: nn = 0; for(i = 0; i < m->u.certificate.ncert; i++) nn += 3 + m->u.certificate.certs[i]->len; if(p + 3 + nn - sendbuf > sizeof(sendbuf)) { tlsError(c, EInternalError, "output buffer too small for certificate"); goto Err; } put24(p, nn); p += 3; for(i = 0; i < m->u.certificate.ncert; i++){ put24(p, m->u.certificate.certs[i]->len); p += 3; memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len); p += m->u.certificate.certs[i]->len; } break; case HClientKeyExchange: n = m->u.clientKeyExchange.key->len; if(c->version != SSL3Version){ put16(p, n); p += 2; } memmove(p, m->u.clientKeyExchange.key->data, n); p += n; break; case HFinished: memmove(p, m->u.finished.verify, m->u.finished.n); p += m->u.finished.n; break; } // go back and fill in size n = p-sendp; assert(p <= sendbuf+sizeof(sendbuf)); put24(sendp+1, n-4); // remember hash of Handshake messages if(m->tag != HHelloRequest) { md5(sendp, n, 0, &c->hsmd5); sha1(sendp, n, 0, &c->hssha1); } sendp = p; if(act == AFlush){ sendp = sendbuf; if(write(c->hand, sendbuf, p-sendbuf) < 0){ fprint(2, "write error: %r\n"); goto Err; } } msgClear(m); return 1; Err: msgClear(m); return 0; } static uchar* tlsReadN(TlsConnection *c, int n) { uchar *p; int nn, nr; nn = c->ep - c->rp; if(nn < n){ if(c->rp != c->buf){ memmove(c->buf, c->rp, nn); c->rp = c->buf; c->ep = 
&c->buf[nn]; } for(; nn < n; nn += nr) { nr = read(c->hand, &c->rp[nn], n - nn); if(nr <= 0) return nil; c->ep += nr; } } p = c->rp; c->rp += n; return p; } static int msgRecv(TlsConnection *c, Msg *m) { uchar *p; int type, n, nn, i, nsid, nrandom, nciph; for(;;) { p = tlsReadN(c, 4); if(p == nil) return 0; type = p[0]; n = get24(p+1); if(type != HHelloRequest) break; if(n != 0) { tlsError(c, EDecodeError, "invalid hello request during handshake"); return 0; } } if(n > sizeof(c->buf)) { tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf)); return 0; } if(type == HSSL2ClientHello){ /* Cope with an SSL3 ClientHello expressed in SSL2 record format. This is sent by some clients that we must interoperate with, such as Java's JSSE and Microsoft's Internet Explorer. */ p = tlsReadN(c, n); if(p == nil) return 0; md5(p, n, 0, &c->hsmd5); sha1(p, n, 0, &c->hssha1); m->tag = HClientHello; if(n < 22) goto Short; m->u.clientHello.version = get16(p+1); p += 3; n -= 3; nn = get16(p); /* cipher_spec_len */ nsid = get16(p + 2); nrandom = get16(p + 4); p += 6; n -= 6; if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */ || nrandom < 16 || nn % 3) goto Err; if(c->trace && (n - nrandom != nn)) c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn); /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */ nciph = 0; for(i = 0; i < nn; i += 3) if(p[i] == 0) nciph++; m->u.clientHello.ciphers = newints(nciph); nciph = 0; for(i = 0; i < nn; i += 3) if(p[i] == 0) m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]); p += nn; m->u.clientHello.sid = makebytes(nil, 0); if(nrandom > RandomSize) nrandom = RandomSize; memset(m->u.clientHello.random, 0, RandomSize - nrandom); memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom); m->u.clientHello.compressors = newbytes(1); m->u.clientHello.compressors->data[0] = CompressionNull; goto Ok; } md5(p, 4, 0, &c->hsmd5); sha1(p, 4, 0, &c->hssha1); p = tlsReadN(c, n); if(p == 
nil) return 0; md5(p, n, 0, &c->hsmd5); sha1(p, n, 0, &c->hssha1); m->tag = type; switch(type) { default: tlsError(c, EUnexpectedMessage, "can't decode a %d", type); goto Err; case HClientHello: if(n < 2) goto Short; m->u.clientHello.version = get16(p); p += 2; n -= 2; if(n < RandomSize) goto Short; memmove(m->u.clientHello.random, p, RandomSize); p += RandomSize; n -= RandomSize; if(n < 1 || n < p[0]+1) goto Short; m->u.clientHello.sid = makebytes(p+1, p[0]); p += m->u.clientHello.sid->len+1; n -= m->u.clientHello.sid->len+1; if(n < 2) goto Short; nn = get16(p); p += 2; n -= 2; if((nn & 1) || n < nn || nn < 2) goto Short; m->u.clientHello.ciphers = newints(nn >> 1); for(i = 0; i < nn; i += 2) m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]); p += nn; n -= nn; if(n < 1 || n < p[0]+1 || p[0] == 0) goto Short; nn = p[0]; m->u.clientHello.compressors = newbytes(nn); memmove(m->u.clientHello.compressors->data, p+1, nn); n -= nn + 1; break; case HServerHello: if(n < 2) goto Short; m->u.serverHello.version = get16(p); p += 2; n -= 2; if(n < RandomSize) goto Short; memmove(m->u.serverHello.random, p, RandomSize); p += RandomSize; n -= RandomSize; if(n < 1 || n < p[0]+1) goto Short; m->u.serverHello.sid = makebytes(p+1, p[0]); p += m->u.serverHello.sid->len+1; n -= m->u.serverHello.sid->len+1; if(n < 3) goto Short; m->u.serverHello.cipher = get16(p); m->u.serverHello.compressor = p[2]; n -= 3; break; case HCertificate: if(n < 3) goto Short; nn = get24(p); p += 3; n -= 3; if(n != nn) goto Short; /* certs */ i = 0; while(n > 0) { if(n < 3) goto Short; nn = get24(p); p += 3; n -= 3; if(nn > n) goto Short; m->u.certificate.ncert = i+1; m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes)); m->u.certificate.certs[i] = makebytes(p, nn); p += nn; n -= nn; i++; } break; case HCertificateRequest: if(n < 1) goto Short; nn = p[0]; p += 1; n -= 1; if(nn < 1 || nn > n) goto Short; m->u.certificateRequest.types = makebytes(p, nn); p += nn; n -= nn; if(n < 
2) goto Short; nn = get16(p); p += 2; n -= 2; /* nn == 0 can happen; yahoo's servers do it */ if(nn != n) goto Short; /* cas */ i = 0; while(n > 0) { if(n < 2) goto Short; nn = get16(p); p += 2; n -= 2; if(nn < 1 || nn > n) goto Short; m->u.certificateRequest.nca = i+1; m->u.certificateRequest.cas = erealloc( m->u.certificateRequest.cas, (i+1)*sizeof(Bytes)); m->u.certificateRequest.cas[i] = makebytes(p, nn); p += nn; n -= nn; i++; } break; case HServerHelloDone: break; case HClientKeyExchange: /* * this message depends upon the encryption selected * assume rsa. */ if(c->version == SSL3Version) nn = n; else{ if(n < 2) goto Short; nn = get16(p); p += 2; n -= 2; } if(n < nn) goto Short; m->u.clientKeyExchange.key = makebytes(p, nn); n -= nn; break; case HFinished: m->u.finished.n = c->finished.n; if(n < m->u.finished.n) goto Short; memmove(m->u.finished.verify, p, m->u.finished.n); n -= m->u.finished.n; break; } if(type != HClientHello && n != 0) goto Short; Ok: if(c->trace){ char *buf; buf = emalloc(8000); c->trace("recv %s", msgPrint(buf, 8000, m)); free(buf); } return 1; Short: tlsError(c, EDecodeError, "handshake message has invalid length"); Err: msgClear(m); return 0; } static void msgClear(Msg *m) { int i; switch(m->tag) { default: sysfatal("msgClear: unknown message type: %d", m->tag); case HHelloRequest: break; case HClientHello: freebytes(m->u.clientHello.sid); freeints(m->u.clientHello.ciphers); freebytes(m->u.clientHello.compressors); break; case HServerHello: freebytes(m->u.clientHello.sid); break; case HCertificate: for(i=0; i<m->u.certificate.ncert; i++) freebytes(m->u.certificate.certs[i]); free(m->u.certificate.certs); break; case HCertificateRequest: freebytes(m->u.certificateRequest.types); for(i=0; i<m->u.certificateRequest.nca; i++) freebytes(m->u.certificateRequest.cas[i]); free(m->u.certificateRequest.cas); break; case HServerHelloDone: break; case HClientKeyExchange: freebytes(m->u.clientKeyExchange.key); break; case HFinished: break; } 
memset(m, 0, sizeof(Msg)); } static char * bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1) { int i; if(s0) bs = seprint(bs, be, "%s", s0); bs = seprint(bs, be, "["); if(b == nil) bs = seprint(bs, be, "nil"); else for(i=0; i<b->len; i++) bs = seprint(bs, be, "%.2x ", b->data[i]); bs = seprint(bs, be, "]"); if(s1) bs = seprint(bs, be, "%s", s1); return bs; } static char * intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1) { int i; if(s0) bs = seprint(bs, be, "%s", s0); bs = seprint(bs, be, "["); if(b == nil) bs = seprint(bs, be, "nil"); else for(i=0; i<b->len; i++) bs = seprint(bs, be, "%x ", b->data[i]); bs = seprint(bs, be, "]"); if(s1) bs = seprint(bs, be, "%s", s1); return bs; } static char* msgPrint(char *buf, int n, Msg *m) { int i; char *bs = buf, *be = buf+n; switch(m->tag) { default: bs = seprint(bs, be, "unknown %d\n", m->tag); break; case HClientHello: bs = seprint(bs, be, "ClientHello\n"); bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version); bs = seprint(bs, be, "\trandom: "); for(i=0; i<RandomSize; i++) bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]); bs = seprint(bs, be, "\n"); bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n"); bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n"); bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n"); break; case HServerHello: bs = seprint(bs, be, "ServerHello\n"); bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version); bs = seprint(bs, be, "\trandom: "); for(i=0; i<RandomSize; i++) bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]); bs = seprint(bs, be, "\n"); bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n"); bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher); bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor); break; case HCertificate: bs = seprint(bs, be, "Certificate\n"); for(i=0; i<m->u.certificate.ncert; i++) bs = bytesPrint(bs, be, "\t", 
m->u.certificate.certs[i], "\n"); break; case HCertificateRequest: bs = seprint(bs, be, "CertificateRequest\n"); bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n"); bs = seprint(bs, be, "\tcertificateauthorities\n"); for(i=0; i<m->u.certificateRequest.nca; i++) bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n"); break; case HServerHelloDone: bs = seprint(bs, be, "ServerHelloDone\n"); break; case HClientKeyExchange: bs = seprint(bs, be, "HClientKeyExchange\n"); bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n"); break; case HFinished: bs = seprint(bs, be, "HFinished\n"); for(i=0; i<m->u.finished.n; i++) bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]); bs = seprint(bs, be, "\n"); break; } USED(bs); return buf; } static void tlsError(TlsConnection *c, int err, char *fmt, ...) { char msg[512]; va_list arg; va_start(arg, fmt); vseprint(msg, msg+sizeof(msg), fmt, arg); va_end(arg); if(c->trace) c->trace("tlsError: %s\n", msg); else if(c->erred) fprint(2, "double error: %r, %s", msg); else werrstr("tls: local %s", msg); c->erred = 1; fprint(c->ctl, "alert %d", err); } // commit to specific version number static int setVersion(TlsConnection *c, int version) { if(c->verset || version > MaxProtoVersion || version < MinProtoVersion) return -1; if(version > c->version) version = c->version; if(version == SSL3Version) { c->version = version; c->finished.n = SSL3FinishedLen; }else if(version == TLSVersion){ c->version = version; c->finished.n = TLSFinishedLen; }else return -1; c->verset = 1; return fprint(c->ctl, "version 0x%x", version); } // confirm that received Finished message matches the expected value static int finishedMatch(TlsConnection *c, Finished *f) { return memcmp(f->verify, c->finished.verify, f->n) == 0; } // free memory associated with TlsConnection struct // (but don't close the TLS channel itself) static void tlsConnectionFree(TlsConnection *c) { tlsSecClose(c->sec); freebytes(c->sid); 
freebytes(c->cert); memset(c, 0, sizeof(c)); free(c); } //================= cipher choices ======================== static int weakCipher[CipherMax] = { 1, /* TLS_NULL_WITH_NULL_NULL */ 1, /* TLS_RSA_WITH_NULL_MD5 */ 1, /* TLS_RSA_WITH_NULL_SHA */ 1, /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */ 0, /* TLS_RSA_WITH_RC4_128_MD5 */ 0, /* TLS_RSA_WITH_RC4_128_SHA */ 1, /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */ 0, /* TLS_RSA_WITH_IDEA_CBC_SHA */ 1, /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */ 0, /* TLS_RSA_WITH_DES_CBC_SHA */ 0, /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */ 1, /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */ 0, /* TLS_DH_DSS_WITH_DES_CBC_SHA */ 0, /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */ 1, /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */ 0, /* TLS_DH_RSA_WITH_DES_CBC_SHA */ 0, /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */ 1, /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */ 0, /* TLS_DHE_DSS_WITH_DES_CBC_SHA */ 0, /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */ 1, /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */ 0, /* TLS_DHE_RSA_WITH_DES_CBC_SHA */ 0, /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */ 1, /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */ 1, /* TLS_DH_anon_WITH_RC4_128_MD5 */ 1, /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */ 1, /* TLS_DH_anon_WITH_DES_CBC_SHA */ 1, /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */ }; static int setAlgs(TlsConnection *c, int a) { int i; for(i = 0; i < nelem(cipherAlgs); i++){ if(cipherAlgs[i].tlsid == a){ c->enc = cipherAlgs[i].enc; c->digest = cipherAlgs[i].digest; c->nsecret = cipherAlgs[i].nsecret; if(c->nsecret > MaxKeyData) return 0; return 1; } } return 0; } static int okCipher(Ints *cv) { int weak, i, j, c; weak = 1; for(i = 0; i < cv->len; i++) { c = cv->data[i]; if(c >= CipherMax) weak = 0; else weak &= weakCipher[c]; for(j = 0; j < nelem(cipherAlgs); j++) if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c) return c; } if(weak) return -2; return -1; } static int okCompression(Bytes *cv) { int i, j, c; for(i = 0; i < cv->len; i++) { c = cv->data[i]; for(j = 0; j < nelem(compressors); j++) { 
if(compressors[j] == c) return c; } } return -1; } static Lock ciphLock; static int nciphers; static int initCiphers(void) { enum {MaxAlgF = 1024, MaxAlgs = 10}; char s[MaxAlgF], *flds[MaxAlgs]; int i, j, n, ok; lock(&ciphLock); if(nciphers){ unlock(&ciphLock); return nciphers; } j = open("#a/tls/encalgs", OREAD); if(j < 0){ werrstr("can't open #a/tls/encalgs: %r"); return 0; } n = read(j, s, MaxAlgF-1); close(j); if(n <= 0){ werrstr("nothing in #a/tls/encalgs: %r"); return 0; } s[n] = 0; n = getfields(s, flds, MaxAlgs, 1, " \t\r\n"); for(i = 0; i < nelem(cipherAlgs); i++){ ok = 0; for(j = 0; j < n; j++){ if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){ ok = 1; break; } } cipherAlgs[i].ok = ok; } j = open("#a/tls/hashalgs", OREAD); if(j < 0){ werrstr("can't open #a/tls/hashalgs: %r"); return 0; } n = read(j, s, MaxAlgF-1); close(j); if(n <= 0){ werrstr("nothing in #a/tls/hashalgs: %r"); return 0; } s[n] = 0; n = getfields(s, flds, MaxAlgs, 1, " \t\r\n"); for(i = 0; i < nelem(cipherAlgs); i++){ ok = 0; for(j = 0; j < n; j++){ if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){ ok = 1; break; } } cipherAlgs[i].ok &= ok; if(cipherAlgs[i].ok) nciphers++; } unlock(&ciphLock); return nciphers; } static Ints* makeciphers(void) { Ints *is; int i, j; is = newints(nciphers); j = 0; for(i = 0; i < nelem(cipherAlgs); i++){ if(cipherAlgs[i].ok) is->data[j++] = cipherAlgs[i].tlsid; } return is; } //================= security functions ======================== // given X.509 certificate, set up connection to factotum // for using corresponding private key static AuthRpc* factotum_rsa_open(uchar *cert, int certlen) { int afd; char *s; mpint *pub = nil; RSApub *rsapub; AuthRpc *rpc; // start talking to factotum if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0) return nil; if((rpc = auth_allocrpc(afd)) == nil){ close(afd); return nil; } s = "proto=rsa service=tls role=client"; if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){ factotum_rsa_close(rpc); return nil; } // roll factotum keyring 
around to match certificate rsapub = X509toRSApub(cert, certlen, nil, 0); while(1){ if(auth_rpc(rpc, "read", nil, 0) != ARok){ factotum_rsa_close(rpc); rpc = nil; goto done; } pub = strtomp(rpc->arg, nil, 16, nil); assert(pub != nil); if(mpcmp(pub,rsapub->n) == 0) break; } done: mpfree(pub); rsapubfree(rsapub); return rpc; } static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher) { char *p; int rv; if((p = mptoa(cipher, 16, nil, 0)) == nil) return nil; rv = auth_rpc(rpc, "write", p, strlen(p)); free(p); if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok) return nil; mpfree(cipher); return strtomp(rpc->arg, nil, 16, nil); } static void factotum_rsa_close(AuthRpc*rpc) { if(!rpc) return; close(rpc->afd); auth_freerpc(rpc); } static void tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { uchar ai[MD5dlen], tmp[MD5dlen]; int i, n; MD5state *s; // generate a1 s = hmac_md5(label, nlabel, key, nkey, nil, nil); s = hmac_md5(seed0, nseed0, key, nkey, nil, s); hmac_md5(seed1, nseed1, key, nkey, ai, s); while(nbuf > 0) { s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil); s = hmac_md5(label, nlabel, key, nkey, nil, s); s = hmac_md5(seed0, nseed0, key, nkey, nil, s); hmac_md5(seed1, nseed1, key, nkey, tmp, s); n = MD5dlen; if(n > nbuf) n = nbuf; for(i = 0; i < n; i++) buf[i] ^= tmp[i]; buf += n; nbuf -= n; hmac_md5(ai, MD5dlen, key, nkey, tmp, nil); memmove(ai, tmp, MD5dlen); } } static void tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { uchar ai[SHA1dlen], tmp[SHA1dlen]; int i, n; SHAstate *s; // generate a1 s = hmac_sha1(label, nlabel, key, nkey, nil, nil); s = hmac_sha1(seed0, nseed0, key, nkey, nil, s); hmac_sha1(seed1, nseed1, key, nkey, ai, s); while(nbuf > 0) { s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil); s = hmac_sha1(label, nlabel, key, nkey, nil, s); s = hmac_sha1(seed0, nseed0, key, nkey, 
nil, s); hmac_sha1(seed1, nseed1, key, nkey, tmp, s); n = SHA1dlen; if(n > nbuf) n = nbuf; for(i = 0; i < n; i++) buf[i] ^= tmp[i]; buf += n; nbuf -= n; hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil); memmove(ai, tmp, SHA1dlen); } } // fill buf with md5(args)^sha1(args) static void tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { int i; int nlabel = strlen(label); int n = (nkey + 1) >> 1; for(i = 0; i < nbuf; i++) buf[i] = 0; tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1); tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1); } /* * for setting server session id's */ static Lock sidLock; static long maxSid = 1; /* the keys are verified to have the same public components * and to function correctly with pkcs 1 encryption and decryption. */ static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom) { TlsSec *sec = emalloc(sizeof(*sec)); USED(csid); USED(ncsid); // ignore csid for now memmove(sec->crandom, crandom, RandomSize); sec->clientVers = cvers; put32(sec->srandom, time(0)); genrandom(sec->srandom+4, RandomSize-4); memmove(srandom, sec->srandom, RandomSize); /* * make up a unique sid: use our pid, and and incrementing id * can signal no sid by setting nssid to 0. 
*/ memset(ssid, 0, SidSize); put32(ssid, getpid()); lock(&sidLock); put32(ssid+4, maxSid++); unlock(&sidLock); *nssid = SidSize; return sec; } static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd) { if(epm != nil){ if(setVers(sec, vers) < 0) goto Err; serverMasterSecret(sec, epm, nepm); }else if(sec->vers != vers){ werrstr("mismatched session versions"); goto Err; } setSecrets(sec, kd, nkd); return 0; Err: sec->ok = -1; return -1; } static TlsSec* tlsSecInitc(int cvers, uchar *crandom) { TlsSec *sec = emalloc(sizeof(*sec)); sec->clientVers = cvers; put32(sec->crandom, time(0)); genrandom(sec->crandom+4, RandomSize-4); memmove(crandom, sec->crandom, RandomSize); return sec; } static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd) { RSApub *pub; pub = nil; USED(sid); USED(nsid); memmove(sec->srandom, srandom, RandomSize); if(setVers(sec, vers) < 0) goto Err; pub = X509toRSApub(cert, ncert, nil, 0); if(pub == nil){ werrstr("invalid x509/rsa certificate"); goto Err; } if(clientMasterSecret(sec, pub, epm, nepm) < 0) goto Err; rsapubfree(pub); setSecrets(sec, kd, nkd); return 0; Err: if(pub != nil) rsapubfree(pub); sec->ok = -1; return -1; } static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient) { if(sec->nfin != nfin){ sec->ok = -1; werrstr("invalid finished exchange"); return -1; } md5.malloced = 0; sha1.malloced = 0; (*sec->setFinished)(sec, md5, sha1, fin, isclient); return 1; } static void tlsSecOk(TlsSec *sec) { if(sec->ok == 0) sec->ok = 1; } static void tlsSecKill(TlsSec *sec) { if(!sec) return; factotum_rsa_close(sec->rpc); sec->ok = -1; } static void tlsSecClose(TlsSec *sec) { if(!sec) return; factotum_rsa_close(sec->rpc); free(sec->server); free(sec); } static int setVers(TlsSec *sec, int v) { if(v == SSL3Version){ sec->setFinished = sslSetFinished; sec->nfin = SSL3FinishedLen; 
sec->prf = sslPRF; }else if(v == TLSVersion){ sec->setFinished = tlsSetFinished; sec->nfin = TLSFinishedLen; sec->prf = tlsPRF; }else{ werrstr("invalid version"); return -1; } sec->vers = v; return 0; } /* * generate secret keys from the master secret. * * different crypto selections will require different amounts * of key expansion and use of key expansion data, * but it's all generated using the same function. */ static void setSecrets(TlsSec *sec, uchar *kd, int nkd) { (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion", sec->srandom, RandomSize, sec->crandom, RandomSize); } /* * set the master secret from the pre-master secret. */ static void setMasterSecret(TlsSec *sec, Bytes *pm) { (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret", sec->crandom, RandomSize, sec->srandom, RandomSize); } static void serverMasterSecret(TlsSec *sec, uchar *epm, int nepm) { Bytes *pm; pm = pkcs1_decrypt(sec, epm, nepm); // if the client messed up, just continue as if everything is ok, // to prevent attacks to check for correctly formatted messages. // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client. if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){ fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n", sec->ok, pm, pm ? 
get16(pm->data) : -1, sec->clientVers, nepm); sec->ok = -1; if(pm != nil) freebytes(pm); pm = newbytes(MasterSecretSize); genrandom(pm->data, MasterSecretSize); } setMasterSecret(sec, pm); memset(pm->data, 0, pm->len); freebytes(pm); } static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm) { Bytes *pm, *key; pm = newbytes(MasterSecretSize); put16(pm->data, sec->clientVers); genrandom(pm->data+2, MasterSecretSize - 2); setMasterSecret(sec, pm); key = pkcs1_encrypt(pm, pub, 2); memset(pm->data, 0, pm->len); freebytes(pm); if(key == nil){ werrstr("tls pkcs1_encrypt failed"); return -1; } *nepm = key->len; *epm = malloc(*nepm); if(*epm == nil){ freebytes(key); werrstr("out of memory"); return -1; } memmove(*epm, key->data, *nepm); freebytes(key); return 1; } static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) { DigestState *s; uchar h0[MD5dlen], h1[SHA1dlen], pad[48]; char *label; if(isClient) label = "CLNT"; else label = "SRVR"; md5((uchar*)label, 4, nil, &hsmd5); md5(sec->sec, MasterSecretSize, nil, &hsmd5); memset(pad, 0x36, 48); md5(pad, 48, nil, &hsmd5); md5(nil, 0, h0, &hsmd5); memset(pad, 0x5C, 48); s = md5(sec->sec, MasterSecretSize, nil, nil); s = md5(pad, 48, nil, s); md5(h0, MD5dlen, finished, s); sha1((uchar*)label, 4, nil, &hssha1); sha1(sec->sec, MasterSecretSize, nil, &hssha1); memset(pad, 0x36, 40); sha1(pad, 40, nil, &hssha1); sha1(nil, 0, h1, &hssha1); memset(pad, 0x5C, 40); s = sha1(sec->sec, MasterSecretSize, nil, nil); s = sha1(pad, 40, nil, s); sha1(h1, SHA1dlen, finished + MD5dlen, s); } // fill "finished" arg with md5(args)^sha1(args) static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) { uchar h0[MD5dlen], h1[SHA1dlen]; char *label; // get current hash value, but allow further messages to be hashed in md5(nil, 0, h0, &hsmd5); sha1(nil, 0, h1, &hssha1); if(isClient) label = "client finished"; else label = "server 
finished"; tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); } static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { DigestState *s; uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26]; int i, n, len; USED(label); len = 1; while(nbuf > 0){ if(len > 26) return; for(i = 0; i < len; i++) tmp[i] = 'A' - 1 + len; s = sha1(tmp, len, nil, nil); s = sha1(key, nkey, nil, s); s = sha1(seed0, nseed0, nil, s); sha1(seed1, nseed1, sha1dig, s); s = md5(key, nkey, nil, nil); md5(sha1dig, SHA1dlen, md5dig, s); n = MD5dlen; if(n > nbuf) n = nbuf; memmove(buf, md5dig, n); buf += n; nbuf -= n; len++; } } static mpint* bytestomp(Bytes* bytes) { mpint* ans; ans = betomp(bytes->data, bytes->len, nil); return ans; } /* * Convert mpint* to Bytes, putting high order byte first. */ static Bytes* mptobytes(mpint* big) { int n, m; uchar *a; Bytes* ans; a = nil; n = (mpsignif(big)+7)/8; m = mptobe(big, nil, n, &a); ans = makebytes(a, m); if(a != nil) free(a); return ans; } // Do RSA computation on block according to key, and pad // result on left with zeros to make it modlen long. 
static Bytes*
rsacomp(Bytes* block, RSApub* key, int modlen)
{
	mpint *x, *y;
	Bytes *a, *ybytes;
	int ylen;

	x = bytestomp(block);
	y = rsaencrypt(key, x, nil);
	mpfree(x);
	ybytes = mptobytes(y);
	ylen = ybytes->len;

	if(ylen < modlen) {
		// left-pad with zero bytes up to modlen
		a = newbytes(modlen);
		memset(a->data, 0, modlen-ylen);
		memmove(a->data+modlen-ylen, ybytes->data, ylen);
		freebytes(ybytes);
		ybytes = a;
	}
	else if(ylen > modlen) {
		// assume it has leading zeros (mod should make it so):
		// keep the low-order modlen bytes, i.e. drop from the front
		a = newbytes(modlen);
		memmove(a->data, ybytes->data+(ylen-modlen), modlen);
		freebytes(ybytes);
		ybytes = a;
	}
	mpfree(y);
	return ybytes;
}

// encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
//
// The block fed to RSA is 00 | blocktype | pad | 00 | data, modlen
// bytes in all.  The pad is all 00 for block type 0, all ff for type 1
// and non-zero random bytes otherwise.
// Returns nil if data does not leave room for at least 8 pad bytes.
static Bytes*
pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
{
	Bytes *pad, *eb, *ans;
	int i, dlen, padlen, modlen;

	modlen = (mpsignif(key->n)+7)/8;
	dlen = data->len;
	if(modlen < 12 || dlen > modlen - 11)
		return nil;
	padlen = modlen - 3 - dlen;
	pad = newbytes(padlen);
	genrandom(pad->data, padlen);
	for(i = 0; i < padlen; i++) {
		if(blocktype == 0)
			pad->data[i] = 0;
		else if(blocktype == 1)
			pad->data[i] = 255;
		else if(pad->data[i] == 0)
			pad->data[i] = 1;	// a zero would end the pad early
	}
	eb = newbytes(modlen);
	eb->data[0] = 0;
	eb->data[1] = blocktype;
	memmove(eb->data+2, pad->data, padlen);
	eb->data[padlen+2] = 0;
	memmove(eb->data+padlen+3, data->data, dlen);
	ans = rsacomp(eb, key, modlen);
	// eb holds a clear copy of data (the premaster secret); wipe it
	memset(eb->data, 0, eb->len);
	freebytes(eb);
	freebytes(pad);
	return ans;
}

// decrypt data according to PKCS#1, with given key.
// expect a block type of 2.
static Bytes* pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm) { Bytes *eb, *ans = nil; int i, modlen; mpint *x, *y; modlen = (mpsignif(sec->rsapub->n)+7)/8; if(nepm != modlen) return nil; x = betomp(epm, nepm, nil); y = factotum_rsa_decrypt(sec->rpc, x); if(y == nil) return nil; eb = mptobytes(y); if(eb->len < modlen){ // pad on left with zeros ans = newbytes(modlen); memset(ans->data, 0, modlen-eb->len); memmove(ans->data+modlen-eb->len, eb->data, eb->len); freebytes(eb); eb = ans; } if(eb->data[0] == 0 && eb->data[1] == 2) { for(i = 2; i < modlen; i++) if(eb->data[i] == 0) break; if(i < modlen - 1) ans = makebytes(eb->data+i+1, modlen-(i+1)); } freebytes(eb); return ans; } //================= general utility functions ======================== static void * emalloc(int n) { void *p; if(n==0) n=1; p = malloc(n); if(p == nil){ exits("out of memory"); } memset(p, 0, n); setmalloctag(p, getcallerpc(&n)); return p; } static void * erealloc(void *ReallocP, int ReallocN) { if(ReallocN == 0) ReallocN = 1; if(!ReallocP) ReallocP = emalloc(ReallocN); else if(!(ReallocP = realloc(ReallocP, ReallocN))){ exits("out of memory"); } setrealloctag(ReallocP, getcallerpc(&ReallocP)); return(ReallocP); } static void put32(uchar *p, u32int x) { p[0] = x>>24; p[1] = x>>16; p[2] = x>>8; p[3] = x; } static void put24(uchar *p, int x) { p[0] = x>>16; p[1] = x>>8; p[2] = x; } static void put16(uchar *p, int x) { p[0] = x>>8; p[1] = x; } static u32int get32(uchar *p) { return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3]; } static int get24(uchar *p) { return (p[0]<<16)|(p[1]<<8)|p[2]; } static int get16(uchar *p) { return (p[0]<<8)|p[1]; } #define OFFSET(x, s) offsetof(s, x) /* * malloc and return a new Bytes structure capable of * holding len bytes. (len >= 0) * Used to use crypt_malloc, which aborts if malloc fails. 
*/ static Bytes* newbytes(int len) { Bytes* ans; ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len); ans->len = len; return ans; } /* * newbytes(len), with data initialized from buf */ static Bytes* makebytes(uchar* buf, int len) { Bytes* ans; ans = newbytes(len); memmove(ans->data, buf, len); return ans; } static void freebytes(Bytes* b) { if(b != nil) free(b); } /* len is number of ints */ static Ints* newints(int len) { Ints* ans; ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int)); ans->len = len; return ans; } static Ints* makeints(int* buf, int len) { Ints* ans; ans = newints(len); if(len > 0) memmove(ans->data, buf, len*sizeof(int)); return ans; } static void freeints(Ints* b) { if(b != nil) free(b); }