shithub: riscv

ref: bd510d7fb0f1fd312319be4d9fd37a9c4bb5169a
dir: /sys/src/libsec/port/ecc.c/

View raw version
#include "os.h"
#include <mp.h>
#include <libsec.h>
#include <ctype.h>

/*
 * Jacobian-coordinate primitives implemented outside this file.
 * Every caller below passes dom->p as p and dom->a as a, and hands the
 * destination point's x, y, z as X3, Y3, Z3.
 * NOTE(review): jacobian_affine presumably rewrites X and Y in place as
 * affine coordinates; callers here discard Z right after calling it.
 */
extern void jacobian_affine(mpint *p,
	mpint *X, mpint *Y, mpint *Z);
extern void jacobian_dbl(mpint *p, mpint *a,
	mpint *X1, mpint *Y1, mpint *Z1,
	mpint *X3, mpint *Y3, mpint *Z3);
extern void jacobian_add(mpint *p, mpint *a,
	mpint *X1, mpint *Y1, mpint *Z1,
	mpint *X2, mpint *Y2, mpint *Z2,
	mpint *X3, mpint *Y3, mpint *Z3);

/*
 * ecassign copies point a into point b.
 *
 * The infinity flag is always copied; when it is set nothing else is
 * touched.  Otherwise x and y are copied and z is handled according to
 * the representation of the destination:
 *   - b has a z: it receives a's z, or 1 when a has none;
 *   - b has no z but a does: a temporary copy of a's z is used to run
 *     jacobian_affine on b's x and y, then released, leaving b->z nil.
 */
void
ecassign(ECdomain *dom, ECpoint *a, ECpoint *b)
{
	mpint *srcz;

	b->inf = a->inf;
	if(b->inf != 0)
		return;

	mpassign(a->x, b->x);
	mpassign(a->y, b->y);

	srcz = a->z;
	if(b->z == nil){
		/* destination wants x, y only */
		if(srcz == nil)
			return;
		b->z = mpcopy(srcz);
		jacobian_affine(dom->p, b->x, b->y, b->z);
		mpfree(b->z);
		b->z = nil;
		return;
	}

	/* destination carries a z coordinate */
	if(srcz == nil)
		srcz = mpone;
	mpassign(srcz, b->z);
}

/*
 * ecadd stores a + b in s.
 *
 * If either operand is the point at infinity the other one is copied
 * (or s is just flagged infinite when both are).  When s has no z
 * coordinate, a temporary z is attached, the sum is computed through a
 * recursive call, the result is passed through jacobian_affine unless it
 * is infinite, and the temporary z is released again.
 * a == b (same pointer) selects jacobian_dbl instead of jacobian_add.
 * Operands without a z coordinate are treated as having z = 1.
 */
void
ecadd(ECdomain *dom, ECpoint *a, ECpoint *b, ECpoint *s)
{
	mpint *za, *zb;

	if(a->inf){
		if(b->inf)
			s->inf = 1;
		else
			ecassign(dom, b, s);
		return;
	}
	if(b->inf){
		ecassign(dom, a, s);
		return;
	}

	if(s->z == nil){
		/* lend s a z for the computation, then take it away again */
		s->z = mpcopy(mpone);
		ecadd(dom, a, b, s);
		if(!s->inf)
			jacobian_affine(dom->p, s->x, s->y, s->z);
		mpfree(s->z);
		s->z = nil;
		return;
	}

	za = a->z != nil ? a->z : mpone;
	if(a == b){
		jacobian_dbl(dom->p, dom->a,
			a->x, a->y, za,
			s->x, s->y, s->z);
	}else{
		zb = b->z != nil ? b->z : mpone;
		jacobian_add(dom->p, dom->a,
			a->x, a->y, za,
			b->x, b->y, zb,
			s->x, s->y, s->z);
	}
	/* a zero z marks the point at infinity */
	s->inf = mpcmp(s->z, mpzero) == 0;
}

/*
 * ecmul stores k*a in s using binary double-and-add over the bits of
 * |k|, lowest bit first.  A negative k negates the y coordinate of the
 * result modulo dom->p.  The result is the point at infinity when a is
 * infinite or k is zero.
 */
void
ecmul(ECdomain *dom, ECpoint *a, mpint *k, ECpoint *s)
{
	ECpoint acc, base;
	mpint *e;

	if(a->inf || mpcmp(k, mpzero) == 0){
		s->inf = 1;
		return;
	}

	/* acc starts out as infinity, base as a copy of a; both carry a z */
	acc.inf = 1;
	acc.x = mpnew(0);
	acc.y = mpnew(0);
	acc.z = mpnew(0);
	base.x = mpnew(0);
	base.y = mpnew(0);
	base.z = mpnew(0);
	ecassign(dom, a, &base);

	/* e = |k|, consumed one bit per iteration */
	e = mpcopy(k);
	e->sign = 1;
	for(; mpcmp(e, mpzero) != 0; mpright(e, 1, e)){
		if(e->p[0] & 1)
			ecadd(dom, &base, &acc, &acc);
		ecadd(dom, &base, &base, &base);
	}
	mpfree(e);

	if(!acc.inf && k->sign < 0){
		acc.y->sign = -1;
		mpmod(acc.y, dom->p, acc.y);
	}
	ecassign(dom, &acc, s);

	mpfree(acc.x);
	mpfree(acc.y);
	mpfree(acc.z);
	mpfree(base.x);
	mpfree(base.y);
	mpfree(base.z);
}

/*
 * ecverify reports whether point a satisfies the curve equation
 *	y*y = (x*x + a)*x + b   (mod dom->p).
 * The point at infinity is accepted.  a must not carry a z coordinate.
 * Returns 1 when the point is on the curve, 0 otherwise.
 */
int
ecverify(ECdomain *dom, ECpoint *a)
{
	mpint *lhs, *rhs;
	int ok;

	if(a->inf)
		return 1;

	assert(a->z == nil);	/* need affine coordinates */
	lhs = mpnew(0);
	rhs = mpnew(0);

	/* right-hand side: ((x*x + a)*x + b) mod p */
	mpmodmul(a->x, a->x, dom->p, rhs);
	mpmodadd(rhs, dom->a, dom->p, rhs);
	mpmodmul(rhs, a->x, dom->p, rhs);
	mpmodadd(rhs, dom->b, dom->p, rhs);

	/* left-hand side: y*y mod p */
	mpmodmul(a->y, a->y, dom->p, lhs);

	ok = mpcmp(lhs, rhs) == 0;
	mpfree(lhs);
	mpfree(rhs);
	return ok;
}

/*
 * ecpubverify checks a public key: it must not be the point at
 * infinity, it must pass ecverify, and multiplying it by dom->n must
 * give the point at infinity.  Returns non-zero when all checks pass.
 */
int
ecpubverify(ECdomain *dom, ECpub *a)
{
	ECpoint t;
	int isinf;

	if(a->inf || !ecverify(dom, a))
		return 0;

	t.x = mpnew(0);
	t.y = mpnew(0);
	t.z = mpnew(0);
	ecmul(dom, a, dom->n, &t);
	isinf = t.inf;
	mpfree(t.x);
	mpfree(t.y);
	mpfree(t.z);
	return isinf;
}

/*
 * fixnibble replaces the hex digit character at *a by its numeric
 * value.  No validation is done here; octet only calls it on bytes
 * that passed isxdigit.
 */
static void
fixnibble(uchar *a)
{
	uchar c;

	c = *a;
	if(c >= 'a')
		c -= 'a'-10;	/* 'a'..'f' -> 10..15 */
	else if(c >= 'A')
		c -= 'A'-10;	/* 'A'..'F' -> 10..15 */
	else
		c -= '0';	/* '0'..'9' -> 0..9 */
	*a = c;
}

/*
 * octet reads two hex digits from *s and returns the byte they encode,
 * or -1 as soon as a character is not a hex digit.  *s is advanced
 * past every character examined, including the offending one.
 */
static int
octet(char **s)
{
	uchar hi, lo;

	hi = **s;
	(*s)++;
	if(!isxdigit(hi))
		return -1;

	lo = **s;
	(*s)++;
	if(!isxdigit(lo))
		return -1;

	fixnibble(&hi);
	fixnibble(&lo);
	return (hi << 4) | lo;
}

/*
 * halfpt parses one field element from the hex string s into out.
 *
 * Exactly n = 2*ceil(bits(dom->p)/8) characters are considered; they
 * are copied into a NUL-terminated scratch buffer so strtomp cannot
 * read further into s.  On return *rptr points into s at the position
 * where strtomp stopped.
 *
 * Returns the result of strtomp, or nil when s is shorter than n
 * characters or the scratch buffer cannot be allocated; in those two
 * cases *rptr is left untouched.
 */
static mpint*
halfpt(ECdomain *dom, char *s, char **rptr, mpint *out)
{
	char *buf, *r;
	int n;
	mpint *ret;

	n = ((mpsignif(dom->p)+7)/8)*2;
	if(strlen(s) < n)
		return nil;
	buf = malloc(n+1);
	if(buf == nil)
		return nil;
	memcpy(buf, s, n);
	buf[n] = 0;
	ret = strtomp(buf, &r, 16, out);
	/* translate the stop position from the scratch buffer back to s */
	*rptr = s + (r - buf);
	free(buf);
	return ret;
}

/*
 * mpleg computes the Jacobi symbol (a/b) with the binary algorithm:
 * reduce, strip factors of two from the numerator (flipping the sign
 * when an odd number of twos is removed and the denominator is 3 or 5
 * mod 8), apply quadratic reciprocity (flipping the sign when both
 * values are 3 mod 4) and swap.
 *
 * Returns 1, -1, or 0 when the reduction reaches zero.  For the odd
 * prime moduli used in this file that is the Legendre symbol.
 * b is expected to be odd and greater than 1; a is expected to be
 * non-negative.
 */
static int
mpleg(mpint *a, mpint *b)
{
	int r, k;
	mpint *m, *n, *t;

	r = 1;
	m = mpcopy(a);
	n = mpcopy(b);
	for(;;){
		/*
		 * Reduce on >=, not just >: with m == n a strict comparison
		 * leaves m untouched, neither exit below fires and the loop
		 * swaps two equal values forever.
		 */
		if(mpcmp(m, n) >= 0)
			mpmod(m, n, m);
		if(mpcmp(m, mpzero) == 0){
			r = 0;
			break;
		}
		if(mpcmp(m, mpone) == 0)
			break;
		k = mplowbits0(m);
		if(k > 0){
			/* (2/n) is -1 exactly when n is 3 or 5 mod 8 */
			if(k & 1)
				switch(n->p[0] & 15){
				case 3: case 5: case 11: case 13:
					r = -r;
				}
			mpright(m, k, m);
		}
		/* reciprocity: sign flips when both are 3 mod 4 */
		if((n->p[0] & 3) == 3 && (m->p[0] & 3) == 3)
			r = -r;
		t = m;
		m = n;
		n = t;
	}
	mpfree(m);
	mpfree(n);
	return r;
}

/*
 * mpsqrt computes a square root of n modulo p and stores it in r
 * (r may alias n).  p is expected to be an odd prime.
 *
 * Returns 0 when mpleg reports n as a non-residue, 1 otherwise.
 *
 * Method (Cipolla-style): pick a random a such that t = a*a - n is a
 * non-residue, then raise x = a + sqrt(t) to the power (p+1)/2 in the
 * extension formed by adjoining sqrt(t).  Elements are kept as pairs
 * (xp, xq) meaning xp + xq*sqrt(t); y accumulates the result.  The
 * sqrt(t) component of the final value must be zero, otherwise the
 * function aborts.
 */
static int
mpsqrt(mpint *n, mpint *p, mpint *r)
{
	mpint *a, *t, *s, *xp, *xq, *yp, *yq, *zp, *zq, *N;

	switch(mpleg(n, p)){
	case -1:
		return 0;
	case 0:
		/*
		 * n is a multiple of p, so its only root is 0.  This must be
		 * answered here: a*a - n is then a*a, which is a residue
		 * for every non-zero a, and the search below would never end.
		 */
		uitomp(0, r);
		return 1;
	}
	a = mpnew(0);
	t = mpnew(0);
	s = mpnew(0);
	N = mpnew(0);
	xp = mpnew(0);
	xq = mpnew(0);
	yp = mpnew(0);
	yq = mpnew(0);
	zp = mpnew(0);
	zq = mpnew(0);
	/* search for non-zero a with a*a - n a non-residue */
	for(;;){
		for(;;){
			mpnrand(p, genrandom, a);
			if(mpcmp(a, mpzero) > 0)
				break;
		}
		mpmul(a, a, t);
		mpsub(t, n, t);
		mpmod(t, p, t);
		if(mpleg(t, p) == -1)
			break;
	}
	/* exponent N = (p+1)/2 */
	mpadd(p, mpone, N);
	mpright(N, 1, N);
	mpmul(a, a, t);
	mpsub(t, n, t);
	/* x = a + 1*sqrt(t), y = 1 */
	mpassign(a, xp);
	uitomp(1, xq);
	uitomp(1, yp);
	uitomp(0, yq);
	while(mpcmp(N, mpzero) != 0){
		if(N->p[0] & 1){
			/* y = y * x */
			mpmul(xp, yp, zp);
			mpmul(xq, yq, zq);
			mpmul(zq, t, zq);
			mpadd(zp, zq, zp);
			mpmod(zp, p, zp);
			mpmul(xp, yq, zq);
			mpmul(xq, yp, s);
			mpadd(zq, s, zq);
			mpmod(zq, p, yq);
			mpassign(zp, yp);
		}
		/* x = x * x */
		mpmul(xp, xp, zp);
		mpmul(xq, xq, zq);
		mpmul(zq, t, zq);
		mpadd(zp, zq, zp);
		mpmod(zp, p, zp);
		mpmul(xp, xq, zq);
		mpadd(zq, zq, zq);
		mpmod(zq, p, xq);
		mpassign(zp, xp);
		mpright(N, 1, N);
	}
	if(mpcmp(yq, mpzero) != 0)
		abort();
	mpassign(yp, r);
	mpfree(a);
	mpfree(t);
	mpfree(s);
	mpfree(N);
	mpfree(xp);
	mpfree(xq);
	mpfree(yp);
	mpfree(yq);
	mpfree(zp);
	mpfree(zq);
	return 1;
}

/*
 * strtoec parses a hex-encoded point from s.
 *
 * The first octet selects the format:
 *   00      point at infinity
 *   02, 03  x only; y is recovered with mpsqrt from x*x*x + a*x + b
 *           and the root whose low bit matches the octet's low bit
 *           is chosen
 *   04      x followed by y
 * Any other lead octet, or a non-hex lead octet, is an error.
 * Parsed points are checked with ecverify.
 *
 * If ret is nil a point (x and y only) is allocated, and released
 * again on failure.  If the caller's ret has a z, it is set to 1 for
 * finite points.  If rptr is not nil, *rptr is set to the position in
 * s where parsing stopped, on success as well as on a parse error.
 *
 * Returns ret, or nil on error.
 */
ECpoint*
strtoec(ECdomain *dom, char *s, char **rptr, ECpoint *ret)
{
	int allocd, o;
	mpint *r;

	allocd = 0;
	if(ret == nil){
		allocd = 1;
		ret = mallocz(sizeof(*ret), 1);
		if(ret == nil)
			return nil;
		ret->x = mpnew(0);
		ret->y = mpnew(0);
	}
	ret->inf = 0;
	o = 0;
	switch(octet(&s)){
	case 0:
		ret->inf = 1;
		break;
	case 3:
		o = 1;
		/* fall through */
	case 2:
		if(halfpt(dom, s, &s, ret->x) == nil)
			goto err;
		/* r = x*x*x + a*x + b, then r = sqrt(r) mod p */
		r = mpnew(0);
		mpmul(ret->x, ret->x, r);
		mpadd(r, dom->a, r);
		mpmul(r, ret->x, r);
		mpadd(r, dom->b, r);
		if(!mpsqrt(r, dom->p, r)){
			mpfree(r);
			goto err;
		}
		/* pick the root with the requested parity */
		if((r->p[0] & 1) != o)
			mpsub(dom->p, r, r);
		mpassign(r, ret->y);
		mpfree(r);
		if(!ecverify(dom, ret))
			goto err;
		break;
	case 4:
		if(halfpt(dom, s, &s, ret->x) == nil)
			goto err;
		if(halfpt(dom, s, &s, ret->y) == nil)
			goto err;
		if(!ecverify(dom, ret))
			goto err;
		break;
	default:
		/* unknown format or not hex: do not hand back an unparsed point */
		goto err;
	}
	if(ret->z != nil && !ret->inf)
		mpassign(mpone, ret->z);
	if(rptr != nil)
		*rptr = s;
	return ret;

err:
	if(rptr)
		*rptr = s;
	if(allocd){
		mpfree(ret->x);
		mpfree(ret->y);
		free(ret);
	}
	return nil;
}

/*
 * ecgen generates a key pair: a random scalar d with 0 < d, drawn with
 * mpnrand bounded by dom->n, and the point d*G stored in p itself.
 * If p is nil a new ECpriv is allocated.  Returns p, or nil when that
 * allocation fails.
 */
ECpriv*
ecgen(ECdomain *dom, ECpriv *p)
{
	if(p == nil){
		if((p = mallocz(sizeof(*p), 1)) == nil)
			return nil;
		p->x = mpnew(0);
		p->y = mpnew(0);
		p->d = mpnew(0);
	}

	/* redraw until the scalar is strictly positive */
	do
		mpnrand(dom->n, genrandom, p->d);
	while(mpcmp(p->d, mpzero) <= 0);

	ecmul(dom, &dom->G, p->d, p);
	return p;
}

/*
 * ecdsasign signs the digest dig[0..len) with priv and stores the
 * signature in r and s.
 *
 * The digest is read big-endian and shifted right so that it is no
 * longer in bits than dom->n.  For each attempt a fresh key pair is
 * generated with ecgen; r is its x coordinate mod n and
 * s = (digest + r*priv->d) / k mod n, k being the fresh scalar.
 * Attempts yielding r == 0 or s == 0 are discarded.
 */
void
ecdsasign(ECdomain *dom, ECpriv *priv, uchar *dig, int len, mpint *r, mpint *s)
{
	ECpriv eph;
	mpint *hash, *kinv;
	int excess;

	eph.x = mpnew(0);
	eph.y = mpnew(0);
	eph.z = nil;
	eph.d = mpnew(0);

	hash = betomp(dig, len, nil);
	kinv = mpnew(0);
	excess = 8*len - mpsignif(dom->n);
	if(excess > 0)
		mpright(hash, excess, hash);

	do{
		do{
			ecgen(dom, &eph);
			mpmod(eph.x, dom->n, r);
		}while(mpcmp(r, mpzero) == 0);

		mpmul(r, priv->d, s);
		mpadd(hash, s, s);
		mpinvert(eph.d, dom->n, kinv);
		mpmodmul(s, kinv, dom->n, s);
	}while(mpcmp(s, mpzero) == 0);

	mpfree(kinv);
	mpfree(hash);
	mpfree(eph.x);
	mpfree(eph.y);
	mpfree(eph.d);
}

/*
 * ecdsaverify checks the signature (r, s) over digest dig[0..len)
 * against the public key pub.
 *
 * Both r and s must lie in [1, n-1]; anything else is rejected before
 * any arithmetic is done.  The digest is read big-endian and shifted
 * right so that it is no longer in bits than dom->n.  Then, with
 * t = 1/s mod n, u1 = digest*t and u2 = r*t, the point u1*G + u2*pub is
 * computed; the signature is valid when that point is finite and its
 * affine x coordinate mod n equals r.
 *
 * Returns 1 for a valid signature, 0 otherwise.
 */
int
ecdsaverify(ECdomain *dom, ECpub *pub, uchar *dig, int len, mpint *r, mpint *s)
{
	mpint *E, *t, *u1, *u2;
	ECpoint R, S;
	int ret;

	/* range check on both halves of the signature */
	if(mpcmp(r, mpone) < 0 || mpcmp(s, mpone) < 0
	|| mpcmp(r, dom->n) >= 0 || mpcmp(s, dom->n) >= 0)
		return 0;
	E = betomp(dig, len, nil);
	if(mpsignif(dom->n) < 8*len)
		mpright(E, 8*len - mpsignif(dom->n), E);
	t = mpnew(0);
	u1 = mpnew(0);
	u2 = mpnew(0);
	R.x = mpnew(0);
	R.y = mpnew(0);
	R.z = mpnew(0);
	S.x = mpnew(0);
	S.y = mpnew(0);
	S.z = mpnew(0);
	mpinvert(s, dom->n, t);
	mpmodmul(E, t, dom->n, u1);
	mpmodmul(r, t, dom->n, u2);
	ecmul(dom, &dom->G, u1, &R);
	ecmul(dom, pub, u2, &S);
	ecadd(dom, &R, &S, &R);
	ret = 0;
	if(!R.inf){
		/* R carries a z; reduce to affine before reading x */
		jacobian_affine(dom->p, R.x, R.y, R.z);
		mpmod(R.x, dom->n, t);
		ret = mpcmp(r, t) == 0;
	}
	mpfree(E);
	mpfree(t);
	mpfree(u1);
	mpfree(u2);
	mpfree(R.x);
	mpfree(R.y);
	mpfree(R.z);
	mpfree(S.x);
	mpfree(S.y);
	mpfree(S.z);
	return ret;
}

/*
 * Base58 digit alphabet: the character at index i encodes digit i.
 * The digits 0 and the letters O, I and l are not part of it.
 * base58enc indexes it by digit value; base58dec looks characters up
 * in it with strchr.
 */
static char *code = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";

/*
 * base58enc writes the base58 encoding of src[0..len) to dst.
 *
 * src is read as one big-endian integer and repeatedly divided by 58;
 * each remainder selects a character from code.  Every leading zero
 * byte of src adds one code[0] character.  Digits are produced least
 * significant first and reversed in place at the end.
 *
 * The output is NOT NUL-terminated and its length is not returned;
 * dst must be large enough for the encoded text.
 */
void
base58enc(uchar *src, char *dst, int len)
{
	mpint *n, *r, *b;
	char *sdst, t;
	int i;

	sdst = dst;
	n = betomp(src, len, nil);
	b = uitomp(58, nil);
	r = mpnew(0);
	while(mpcmp(n, mpzero) != 0){
		mpdiv(n, b, n, r);
		*dst++ = code[mptoui(r)];
	}
	mpfree(n);
	mpfree(b);
	mpfree(r);

	/* leading zero bytes; stop at len so all-zero input stays in bounds */
	for(i = 0; i < len && src[i] == 0; i++)
		*dst++ = code[0];

	/* nothing written: avoid stepping dst below the start of the buffer */
	if(dst == sdst)
		return;

	/* reverse in place */
	dst--;
	while(dst > sdst){
		t = *sdst;
		*sdst++ = *dst;
		*dst-- = t;
	}
}

/*
 * base58dec decodes the NUL-terminated base58 string src and writes
 * the value to dst with mptober, using len as the output size.
 * Returns 0 on success.  On a character that is not in code it sets
 * the error string, leaves dst unwritten and returns -1.
 */
int
base58dec(char *src, uchar *dst, int len)
{
	mpint *acc, *radix, *digit;
	char *hit;
	int rv;

	acc = mpnew(0);
	digit = mpnew(0);
	radix = uitomp(58, nil);

	rv = 0;
	while(*src != 0){
		hit = strchr(code, *src);
		if(hit == nil){
			werrstr("invalid base58 char");
			rv = -1;
			break;
		}
		/* acc = acc*58 + digit */
		uitomp(hit - code, digit);
		mpmul(acc, radix, acc);
		mpadd(acc, digit, acc);
		src++;
	}
	if(rv == 0)
		mptober(acc, dst, len);

	mpfree(acc);
	mpfree(digit);
	mpfree(radix);
	return rv;
}

/*
 * ecdominit clears dom and allocates its parameters p, a, b, the
 * generator coordinates G.x and G.y, n and h.  When init is not nil it
 * is called to fill them in, after which dom->p is replaced by the
 * result of mpfield(dom->p).
 */
void
ecdominit(ECdomain *dom, void (*init)(mpint *p, mpint *a, mpint *b, mpint *x, mpint *y, mpint *n, mpint *h))
{
	memset(dom, 0, sizeof(*dom));

	dom->p = mpnew(0);
	dom->a = mpnew(0);
	dom->b = mpnew(0);
	dom->G.x = mpnew(0);
	dom->G.y = mpnew(0);
	dom->n = mpnew(0);
	dom->h = mpnew(0);

	if(init == nil)
		return;

	(*init)(dom->p, dom->a, dom->b, dom->G.x, dom->G.y, dom->n, dom->h);
	dom->p = mpfield(dom->p);
}

/*
 * ecdomfree releases every mpint allocated by ecdominit and clears
 * dom.  The ECdomain structure itself is not freed.
 */
void
ecdomfree(ECdomain *dom)
{
	/* generator */
	mpfree(dom->G.x);
	mpfree(dom->G.y);

	/* curve and group parameters */
	mpfree(dom->p);
	mpfree(dom->a);
	mpfree(dom->b);
	mpfree(dom->n);
	mpfree(dom->h);

	memset(dom, 0, sizeof(*dom));
}

/*
 * ecencodepub writes pub to data as the octet 0x04 followed by x and
 * y, each written with mptober using ceil(bits(dom->p)/8) bytes.
 * Returns the number of bytes written, or 0 when len is too small.
 */
int
ecencodepub(ECdomain *dom, ECpub *pub, uchar *data, int len)
{
	int flen, total;

	flen = (mpsignif(dom->p)+7)/8;
	total = 1 + 2*flen;
	if(len < total)
		return 0;

	*data++ = 0x04;
	mptober(pub->x, data, flen);
	mptober(pub->y, data+flen, flen);
	return total;
}

/*
 * ecdecodepub parses a public key from data: the octet 0x04 followed
 * by x and y, each ceil(bits(dom->p)/8) bytes, big-endian.  len must
 * match that size exactly.  The key is checked with ecpubverify.
 * Returns a newly allocated ECpub (release with ecpubfree), or nil on
 * a malformed input, a failed check or a failed allocation.
 */
ECpub*
ecdecodepub(ECdomain *dom, uchar *data, int len)
{
	ECpub *pub;
	int flen;

	flen = (mpsignif(dom->p)+7)/8;
	/* length first: data[0] is only read once the size is right */
	if(len != 1 + 2*flen)
		return nil;
	if(data[0] != 0x04)
		return nil;

	pub = mallocz(sizeof(*pub), 1);
	if(pub == nil)
		return nil;
	pub->x = betomp(data+1, flen, nil);
	pub->y = betomp(data+1+flen, flen, nil);

	if(ecpubverify(dom, pub))
		return pub;
	ecpubfree(pub);
	return nil;
}

/*
 * ecpubfree releases the x and y coordinates of p and then p itself.
 * A nil p is ignored.
 */
void
ecpubfree(ECpub *p)
{
	if(p != nil){
		mpfree(p->x);
		mpfree(p->y);
		free(p);
	}
}