ref: 8561a843070582d48ef7f99e1138941c3509975f
dir: /sys/src/cmd/auth/factotum/rsa.c/
/*
 * RSA authentication.
 *
 * Old ssh client protocol:
 *	read public key
 *		if you don't like it, read another, repeat
 *	write challenge
 *	read response
 *
 * all numbers are hexadecimal bigints parsable with strtomp.
 *
 * Sign (PKCS #1 v1.5 using the key's hash attribute:
 * hash=sha1 (the default), hash=md5 or hash=sha256)
 *	write hash(msg)
 *	read signature(hash(msg))
 *
 * Verify:
 *	write hash(msg)
 *	write signature(hash(msg))
 *	read ok or fail
 */

#include "dat.h"

enum {
	CHavePub,
	CHaveResp,
	VNeedHash,
	VNeedSig,
	VHaveResp,
	SNeedHash,
	SHaveResp,
	Maxphase,
};

static char *phasenames[] = {
[CHavePub]	"CHavePub",
[CHaveResp]	"CHaveResp",
[VNeedHash]	"VNeedHash",
[VNeedSig]	"VNeedSig",
[VHaveResp]	"VHaveResp",
[SNeedHash]	"SNeedHash",
[SHaveResp]	"SHaveResp",
};

struct State
{
	RSApriv	*priv;
	mpint	*resp;		/* client: decrypted challenge */
	int	off;		/* client: number of keys offered so far */
	Key	*key;
	mpint	*digest;	/* sign/verify: PKCS#1-encoded hash */
	int	sigresp;	/* verify: 0 if the signature matched */
};

static mpint* mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen);

/*
 * Build an RSApriv from a key's attributes.
 * The public half (ek, n) is required; if the key has no
 * private attributes, only the public half is filled in.
 * Returns nil if any required attribute is missing or unparsable.
 */
static RSApriv*
readrsapriv(Key *k)
{
	char *a;
	RSApriv *priv;

	priv = rsaprivalloc();

	if((a=_strfindattr(k->attr, "ek"))==nil
	|| (priv->pub.ek=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->attr, "n"))==nil
	|| (priv->pub.n=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if(k->privattr == nil)	/* only public half */
		return priv;

	if((a=_strfindattr(k->privattr, "!p"))==nil
	|| (priv->p=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!q"))==nil
	|| (priv->q=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!kp"))==nil
	|| (priv->kp=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!kq"))==nil
	|| (priv->kq=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!c2"))==nil
	|| (priv->c2=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!dk"))==nil
	|| (priv->dk=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	return priv;

Error:
	rsaprivfree(priv);
	return nil;
}

/*
 * Start a conversation.  The "role" attribute selects the
 * initial phase.  Sign and verify look up their key here;
 * the client role picks keys one at a time in rsaread.
 */
static int
rsainit(Proto*, Fsstate *fss)
{
	Keyinfo ki;
	State *s;
	char *role;

	if((role = _strfindattr(fss->attr, "role")) == nil)
		return failure(fss, "rsa role not specified");
	if(strcmp(role, "client") == 0)
		fss->phase = CHavePub;
	else if(strcmp(role, "sign") == 0)
		fss->phase = SNeedHash;
	else if(strcmp(role, "verify") == 0)
		fss->phase = VNeedHash;
	else
		return failure(fss, "rsa role %s unimplemented", role);

	s = emalloc(sizeof *s);
	fss->phasename = phasenames;
	fss->maxphase = Maxphase;
	fss->ps = s;

	switch(fss->phase){
	case SNeedHash:
	case VNeedHash:
		mkkeyinfo(&ki, fss, nil);
		if(findkey(&s->key, &ki, nil) != RpcOk)
			return failure(fss, nil);
		/* signing needs private key */
		if(fss->phase == SNeedHash && s->key->privattr == nil)
			return failure(fss,
				"missing private half of key -- cannot sign");
	}
	return RpcOk;
}

/*
 * Reads produce, depending on phase: the next candidate public
 * key (client), the challenge response (client), the signature
 * bytes (sign), or the verification verdict (verify).
 */
static int
rsaread(Fsstate *fss, void *va, uint *n)
{
	RSApriv *priv;
	State *s;
	mpint *m;
	Keyinfo ki;
	int len;

	s = fss->ps;
	switch(fss->phase){
	default:
		return phaseerror(fss, "read");

	case CHavePub:
		if(s->key){
			closekey(s->key);
			s->key = nil;
		}
		mkkeyinfo(&ki, fss, nil);
		ki.skip = s->off;	/* skip keys already offered */
		ki.noconf = 1;
		if(findkey(&s->key, &ki, nil) != RpcOk)
			return failure(fss, nil);
		s->off++;
		priv = s->key->priv;
		*n = snprint(va, *n, "%B %B", priv->pub.n, priv->pub.ek);
		return RpcOk;

	case CHaveResp:
		*n = snprint(va, *n, "%B", s->resp);
		fss->phase = Established;
		return RpcOk;

	case SHaveResp:
		priv = s->key->priv;
		/* signature is exactly the modulus length in bytes */
		len = (mpsignif(priv->pub.n)+7)/8;
		if(len > *n)
			return failure(fss, "signature buffer too short");
		*n = len;
		m = rsadecrypt(priv, s->digest, nil);
		mptober(m, (uchar*)va, len);
		mpfree(m);
		fss->phase = Established;
		return RpcOk;

	case VHaveResp:
		*n = snprint(va, *n, "%s", s->sigresp == 0? "ok":
			"signature does not verify");
		fss->phase = Established;
		return RpcOk;
	}
}

/*
 * Writes supply, depending on phase: the challenge (client),
 * the message hash (sign/verify), or the signature (verify).
 */
static int
rsawrite(Fsstate *fss, void *va, uint n)
{
	RSApriv *priv;
	mpint *m, *mm;
	State *s;
	char *hash;
	int dlen;

	s = fss->ps;
	switch(fss->phase){
	default:
		return phaseerror(fss, "write");

	case CHavePub:
		if(s->key == nil)
			return failure(fss, "no current key");
		switch(canusekey(fss, s->key)){
		case -1:
			return RpcConfirm;
		case 0:
			return failure(fss, "confirmation denied");
		case 1:
			break;
		}
		m = strtomp(va, nil, 16, nil);
		if(m == nil)
			return failure(fss, "invalid challenge value");
		m = rsadecrypt(s->key->priv, m, m);
		s->resp = m;
		fss->phase = CHaveResp;
		return RpcOk;

	case SNeedHash:
	case VNeedHash:
		/* get hash type from key */
		hash = _strfindattr(s->key->attr, "hash");
		if(hash == nil)
			hash = "sha1";
		if(strcmp(hash, "sha1") == 0)
			dlen = SHA1dlen;
		else if(strcmp(hash, "md5") == 0)
			dlen = MD5dlen;
		else if(strcmp(hash, "sha256") == 0)
			dlen = SHA2_256dlen;
		else
			return failure(fss, "unknown hash function %s", hash);
		if(n != dlen)
			return failure(fss, "hash length %d should be %d", n, dlen);
		priv = s->key->priv;
		s->digest = mkdigest(&priv->pub, hash, (uchar *)va, n);
		if(s->digest == nil)
			return failure(fss, nil);
		if(fss->phase == VNeedHash)
			fss->phase = VNeedSig;
		else
			fss->phase = SHaveResp;
		return RpcOk;

	case VNeedSig:
		priv = s->key->priv;
		m = betomp((uchar*)va, n, nil);
		/*
		 * RFC 8017 RSAVP1: a signature representative that is
		 * not in [0, n-1] is invalid.  rsaencrypt reduces mod n,
		 * so without this check sig+n would verify as well as sig.
		 */
		if(mpcmp(m, priv->pub.n) >= 0)
			s->sigresp = 1;
		else{
			mm = rsaencrypt(&priv->pub, m, nil);
			s->sigresp = mpcmp(s->digest, mm);
			mpfree(mm);
		}
		mpfree(m);
		fss->phase = VHaveResp;
		return RpcOk;
	}
}

static void
rsaclose(Fsstate *fss)
{
	State *s;

	s = fss->ps;
	if(s->key)
		closekey(s->key);
	if(s->resp)
		mpfree(s->resp);
	if(s->digest)
		mpfree(s->digest);
	free(s);
}

static int
rsaaddkey(Key *k, int before)
{
	fmtinstall('B', mpfmt);

	if((k->priv = readrsapriv(k)) == nil){
		werrstr("malformed key data");
		return -1;
	}
	return replacekey(k, before);
}

static void
rsaclosekey(Key *k)
{
	rsaprivfree(k->priv);
}

Proto rsa = {
.name=	"rsa",
.init=	rsainit,
.write=	rsawrite,
.read=	rsaread,
.close=	rsaclose,
.addkey=	rsaaddkey,
.closekey=	rsaclosekey,
};

/*
 * Simple ASN.1 encodings.
 * Lengths < 128 are encoded as 1-byte constants,
 * making our life easy.
 */

/*
 * Hash OIDs
 *
 * SHA1 = 1.3.14.3.2.26
 * MDx = 1.2.840.113549.2.x
 * SHA256 = 2.16.840.1.101.3.4.2.1
 */
#define O0(a,b)	((a)*40+(b))
#define O2(x)	\
	(((x)>> 7)&0x7F)|0x80, \
	((x)&0x7F)
#define O3(x)	\
	(((x)>>14)&0x7F)|0x80, \
	(((x)>> 7)&0x7F)|0x80, \
	((x)&0x7F)
uchar oidsha1[] = { O0(1, 3), 14, 3, 2, 26 };
uchar oidmd5[] = { O0(1, 2), O2(840), O3(113549), 2, 5 };
uchar oidsha256[] = { O0(2, 16), O2(840), 1, 101, 3, 4, 2, 1 };

/*
 *	DigestInfo ::= SEQUENCE {
 *		digestAlgorithm AlgorithmIdentifier,
 *		digest OCTET STRING
 *	}
 *
 * where AlgorithmIdentifier is itself
 *
 *	SEQUENCE { algorithm OBJECT IDENTIFIER, parameters NULL }
 *
 * We always emit the explicit NULL parameters, matching the
 * encodings listed in RFC 8017 section 9.2, note 1 (and OpenSSL).
 *
 * Writes the DER encoding into asn1 and returns its length,
 * or -1 (with errstr set) for an unknown algorithm or a
 * digest too long for the short-form lengths used here.
 */
static int
mkasn1(uchar *asn1, char *alg, uchar *d, uint dlen)
{
	uchar *obj, *p;
	uint olen;

	if(strcmp(alg, "sha1") == 0){
		obj = oidsha1;
		olen = sizeof(oidsha1);
	}else if(strcmp(alg, "md5") == 0){
		obj = oidmd5;
		olen = sizeof(oidmd5);
	}else if(strcmp(alg, "sha256") == 0){
		obj = oidsha256;
		olen = sizeof(oidsha256);
	}else{
		/* don't take down factotum over a bad request */
		werrstr("unknown hash algorithm %s", alg);
		return -1;
	}

	/*
	 * Outer sequence content is 2+(2+olen+2) + 2+dlen bytes;
	 * it must fit in a one-byte (short form) DER length.
	 */
	if(olen + dlen + 8 > 0x7F){
		werrstr("digest too long for asn.1 encoding");
		return -1;
	}

	p = asn1;
	*p++ = 0x30;	/* sequence */
	p++;		/* outer length, filled in below */

	*p++ = 0x30;	/* another sequence */
	p++;		/* inner length, filled in below */

	*p++ = 0x06;	/* object id */
	*p++ = olen;
	memmove(p, obj, olen);
	p += olen;

	*p++ = 0x05;	/* null */
	*p++ = 0;

	asn1[3] = p - (asn1+4);	/* end of inner sequence */

	*p++ = 0x04;	/* octet string */
	*p++ = dlen;
	memmove(p, d, dlen);
	p += dlen;

	asn1[1] = p - (asn1+2);	/* end of outer sequence */
	return p - asn1;
}

/*
 * Encode hash as a PKCS#1 v1.5 signature block for key
 * (EMSA-PKCS1-v1_5, RFC 8017 section 9.2) and return it as
 * an mpint.  Returns nil with errstr set on failure.
 */
static mpint*
mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen)
{
	mpint *m;
	uchar asn1[512], *buf;
	int len, n, pad;

	/*
	 * Create ASN.1
	 */
	n = mkasn1(asn1, hashalg, hash, dlen);
	if(n < 0)
		return nil;

	/*
	 * PKCS#1 padding:
	 *	EM = 0x00 || 0x01 || PS || 0x00 || T
	 * PS is 0xFF bytes and must be at least 8 long.
	 * The leading 0x00 is implicit in the mpint, so buf
	 * holds only the remaining len = k-1 bytes.
	 */
	len = (mpsignif(key->n)+7)/8 - 1;
	if(len < n+2+8){
		werrstr("rsa key too short");
		return nil;
	}
	pad = len - (n+2);
	buf = emalloc(len);
	buf[0] = 0x01;
	memset(buf+1, 0xFF, pad);
	buf[1+pad] = 0x00;
	memmove(buf+1+pad+1, asn1, n);
	m = betomp(buf, len, nil);
	free(buf);
	return m;
}