shithub: riscv

ref: 7ab86494986671bf38464435f57fb6f925ce8632
dir: /sys/src/cmd/auth/factotum/rsa.c/

View raw version
/*
 * RSA authentication.
 *
 * Old ssh client protocol:
 *	read public key
 *		if you don't like it, read another, repeat
 *	write challenge
 *	read response
 *
 * all numbers are hexadecimal biginits parsable with strtomp.
 *
 * Sign (PKCS #1 using hash=sha1 or hash=md5)
 *	write hash(msg)
 *	read signature(hash(msg))
 *
 * Verify:
 *	write hash(msg)
 *	write signature(hash(msg))
 *	read ok or fail
 */

#include "dat.h"

enum {
	CHavePub,
	CHaveResp,
	VNeedHash,
	VNeedSig,
	VHaveResp,
	SNeedHash,
	SHaveResp,
	Maxphase,
};

static char *phasenames[] = {
[CHavePub]	"CHavePub",
[CHaveResp]	"CHaveResp",
[VNeedHash]	"VNeedHash",
[VNeedSig]	"VNeedSig",
[VHaveResp]	"VHaveResp",
[SNeedHash]	"SNeedHash",
[SHaveResp]	"SHaveResp",
};

struct State
{
	RSApriv *priv;
	mpint *resp;
	int off;
	Key *key;
	mpint *digest;
	int sigresp;
};

static mpint* mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen);

static RSApriv*
readrsapriv(Key *k)
{
	char *a;
	RSApriv *priv;

	priv = rsaprivalloc();

	if((a=_strfindattr(k->attr, "ek"))==nil || (priv->pub.ek=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->attr, "n"))==nil || (priv->pub.n=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if(k->privattr == nil)		/* only public half */
		return priv;
	if((a=_strfindattr(k->privattr, "!p"))==nil || (priv->p=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!q"))==nil || (priv->q=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!kp"))==nil || (priv->kp=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!kq"))==nil || (priv->kq=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!c2"))==nil || (priv->c2=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	if((a=_strfindattr(k->privattr, "!dk"))==nil || (priv->dk=strtomp(a, nil, 16, nil))==nil)
		goto Error;
	return priv;

Error:
	rsaprivfree(priv);
	return nil;
}

static int
rsainit(Proto*, Fsstate *fss)
{
	Keyinfo ki;
	State *s;
	char *role;

	if((role = _strfindattr(fss->attr, "role")) == nil)
		return failure(fss, "rsa role not specified");
	if(strcmp(role, "client") == 0)
		fss->phase = CHavePub;
	else if(strcmp(role, "sign") == 0)
		fss->phase = SNeedHash;
	else if(strcmp(role, "verify") == 0)
		fss->phase = VNeedHash;
	else
		return failure(fss, "rsa role %s unimplemented", role);

	s = emalloc(sizeof *s);
	fss->phasename = phasenames;
	fss->maxphase = Maxphase;
	fss->ps = s;

	switch(fss->phase){
	case SNeedHash:
	case VNeedHash:
		mkkeyinfo(&ki, fss, nil);
		if(findkey(&s->key, &ki, nil) != RpcOk)
			return failure(fss, nil);
		/* signing needs private key */
		if(fss->phase == SNeedHash && s->key->privattr == nil)
			return failure(fss,
				"missing private half of key -- cannot sign");
	}
	return RpcOk;
}

static int
rsaread(Fsstate *fss, void *va, uint *n)
{
	RSApriv *priv;
	State *s;
	mpint *m;
	Keyinfo ki;
	int len, r;

	s = fss->ps;
	switch(fss->phase){
	default:
		return phaseerror(fss, "read");
	case CHavePub:
		if(s->key){
			closekey(s->key);
			s->key = nil;
		}
		mkkeyinfo(&ki, fss, nil);
		ki.skip = s->off;
		ki.noconf = 1;
		if(findkey(&s->key, &ki, nil) != RpcOk)
			return failure(fss, nil);
		s->off++;
		priv = s->key->priv;
		*n = snprint(va, *n, "%B %B", priv->pub.n, priv->pub.ek);
		return RpcOk;
	case CHaveResp:
		*n = snprint(va, *n, "%B", s->resp);
		fss->phase = Established;
		return RpcOk;
	case SHaveResp:
		priv = s->key->priv;
		len = (mpsignif(priv->pub.n)+7)/8;
		if(len > *n)
			return failure(fss, "signature buffer too short");
		m = rsadecrypt(priv, s->digest, nil);
		r = mptobe(m, (uchar*)va, len, nil);
		if(r < len){
			memmove((uchar*)va+len-r, va, r);
			memset(va, 0, len-r);
		}
		*n = len;
		mpfree(m);
		fss->phase = Established;
		return RpcOk;
	case VHaveResp:
		*n = snprint(va, *n, "%s", s->sigresp == 0? "ok":
			"signature does not verify");
		fss->phase = Established;
		return RpcOk;
	}
}

static int
rsawrite(Fsstate *fss, void *va, uint n)
{
	RSApriv *priv;
	mpint *m, *mm;
	State *s;
	char *hash;
	int dlen;

	s = fss->ps;
	switch(fss->phase){
	default:
		return phaseerror(fss, "write");
	case CHavePub:
		if(s->key == nil)
			return failure(fss, "no current key");
		switch(canusekey(fss, s->key)){
		case -1:
			return RpcConfirm;
		case 0:
			return failure(fss, "confirmation denied");
		case 1:
			break;
		}
		m = strtomp(va, nil, 16, nil);
		if(m == nil)
			return failure(fss, "invalid challenge value");
		m = rsadecrypt(s->key->priv, m, m);
		s->resp = m;
		fss->phase = CHaveResp;
		return RpcOk;
	case SNeedHash:
	case VNeedHash:
		/* get hash type from key */
		hash = _strfindattr(s->key->attr, "hash");
		if(hash == nil)
			hash = "sha1";
		if(strcmp(hash, "sha1") == 0)
			dlen = SHA1dlen;
		else if(strcmp(hash, "md5") == 0)
			dlen = MD5dlen;
		else if(strcmp(hash, "sha256") == 0)
			dlen = SHA2_256dlen;
		else
			return failure(fss, "unknown hash function %s", hash);
		if(n != dlen)
			return failure(fss, "hash length %d should be %d",
				n, dlen);
		priv = s->key->priv;
		s->digest = mkdigest(&priv->pub, hash, (uchar *)va, n);
		if(s->digest == nil)
			return failure(fss, nil);
		if(fss->phase == VNeedHash)
			fss->phase = VNeedSig;
		else
			fss->phase = SHaveResp;
		return RpcOk;
	case VNeedSig:
		priv = s->key->priv;
		m = betomp((uchar*)va, n, nil);
		mm = rsaencrypt(&priv->pub, m, nil);
		s->sigresp = mpcmp(s->digest, mm);
		mpfree(m);
		mpfree(mm);
		fss->phase = VHaveResp;
		return RpcOk;
	}
}

static void
rsaclose(Fsstate *fss)
{
	State *s;

	s = fss->ps;
	if(s->key)
		closekey(s->key);
	if(s->resp)
		mpfree(s->resp);
	if(s->digest)
		mpfree(s->digest);
	free(s);
}

static int
rsaaddkey(Key *k, int before)
{
	fmtinstall('B', mpfmt);

	if((k->priv = readrsapriv(k)) == nil){
		werrstr("malformed key data");
		return -1;
	}
	return replacekey(k, before);
}

static void
rsaclosekey(Key *k)
{
	rsaprivfree(k->priv);
}

Proto rsa = {
.name=	"rsa",
.init=		rsainit,
.write=	rsawrite,
.read=	rsaread,
.close=	rsaclose,
.addkey=	rsaaddkey,
.closekey=	rsaclosekey,
};

/*
 * Simple ASN.1 encodings.
 * Lengths < 128 are encoded as 1-bytes constants,
 * making our life easy.
 */

/*
 * Hash OIDs
 *
 * SHA1 = 1.3.14.3.2.26
 * MDx = 1.2.840.113549.2.x
 * SHA256 = 2.16.840.1.101.3.4.2.1
 */
#define O0(a,b)	((a)*40+(b))
#define O2(x)	\
	(((x)>> 7)&0x7F)|0x80, \
	((x)&0x7F)
#define O3(x)	\
	(((x)>>14)&0x7F)|0x80, \
	(((x)>> 7)&0x7F)|0x80, \
	((x)&0x7F)
uchar oidsha1[] = { O0(1, 3), 14, 3, 2, 26 };
uchar oidmd5[] = { O0(1, 2), O2(840), O3(113549), 2, 5 };
uchar oidsha256[] = { O0(2, 16), O2(840), 1, 101, 3, 4, 2, 1 };
/*
 *	DigestInfo ::= SEQUENCE {
 *		digestAlgorithm AlgorithmIdentifier,
 *		digest OCTET STRING
 *	}
 *
 * except that OpenSSL seems to sign
 *
 *	DigestInfo ::= SEQUENCE {
 *		SEQUENCE{ digestAlgorithm AlgorithmIdentifier, NULL }
 *		digest OCTET STRING
 *	}
 *
 * instead.  Sigh.
 */
static int
mkasn1(uchar *asn1, char *alg, uchar *d, uint dlen)
{
	uchar *obj, *p;
	uint olen;

	if(strcmp(alg, "sha1") == 0){
		obj = oidsha1;
		olen = sizeof(oidsha1);
	}else if(strcmp(alg, "md5") == 0){
		obj = oidmd5;
		olen = sizeof(oidmd5);
	}else if(strcmp(alg, "sha256") == 0){
		obj = oidsha256;
		olen = sizeof(oidsha256);
	}else{
		sysfatal("bad alg in mkasn1");
		return -1;
	}

	p = asn1;
	*p++ = 0x30;		/* sequence */
	p++;

	*p++ = 0x30;		/* another sequence */
	p++;

	*p++ = 0x06;		/* object id */
	*p++ = olen;
	memmove(p, obj, olen);
	p += olen;

	*p++ = 0x05;		/* null */
	*p++ = 0;

	asn1[3] = p - (asn1+4);	/* end of inner sequence */

	*p++ = 0x04;		/* octet string */
	*p++ = dlen;
	memmove(p, d, dlen);
	p += dlen;

	asn1[1] = p - (asn1+2);	/* end of outer sequence */
	return p - asn1;
}

static mpint*
mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen)
{
	mpint *m;
	uchar asn1[512], *buf;
	int len, n, pad;

	/*
	 * Create ASN.1
	 */
	n = mkasn1(asn1, hashalg, hash, dlen);

	/*
	 * PKCS#1 padding
	 */
	len = (mpsignif(key->n)+7)/8 - 1;
	if(len < n+2){
		werrstr("rsa key too short");
		return nil;
	}
	pad = len - (n+2);
	buf = emalloc(len);
	buf[0] = 0x01;
	memset(buf+1, 0xFF, pad);
	buf[1+pad] = 0x00;
	memmove(buf+1+pad+1, asn1, n);
	m = betomp(buf, len, nil);
	free(buf);
	return m;
}