shithub: riscv

Download patch

ref: 38e1e5272fc9c66a00d702246813135452819ffe
parent: b677ab0c5909942bf8946e9e9bd148dea7dae718
author: cinap_lenrek <[email protected]>
date: Sat Nov 21 04:39:59 EST 2015

libmp: initial attempt at constant time code, faster reductions for special primes (for ecc)

introduce MPtimesafe flag to request time invariant computation
disables normalization so significant digits are not leaked.

--- a/sys/include/mp.h
+++ b/sys/include/mp.h
@@ -22,7 +22,10 @@
 
 enum
 {
-	MPstatic=	0x01,
+	MPstatic=	0x01,	/* static constant */
+	MPnorm=		0x02,	/* normalization status */
+	MPtimesafe=	0x04,	/* request time invariant computation */
+
 	Dbytes=		sizeof(mpdigit),	/* bytes per digit */
 	Dbits=		Dbytes*8		/* bits per digit */
 };
@@ -32,7 +35,7 @@
 mpint*	mpnew(int n);		/* create a new mpint with at least n bits */
 void	mpfree(mpint *b);
 void	mpbits(mpint *b, int n);	/* ensure that b has at least n bits */
-void	mpnorm(mpint *b);		/* dump leading zeros */
+mpint*	mpnorm(mpint *b);		/* dump leading zeros */
 mpint*	mpcopy(mpint *b);
 void	mpassign(mpint *old, mpint *new);
 
@@ -47,8 +50,10 @@
 char*	mptoa(mpint*, int, char*, int);
 mpint*	letomp(uchar*, uint, mpint*);	/* byte array, little-endian */
 int	mptole(mpint*, uchar*, uint, uchar**);
+void	mptolel(mpint *b, uchar *p, int n);
 mpint*	betomp(uchar*, uint, mpint*);	/* byte array, big-endian */
 int	mptobe(mpint*, uchar*, uint, uchar**);
+void	mptober(mpint *b, uchar *p, int n);
 uint	mptoui(mpint*);			/* unsigned int */
 mpint*	uitomp(uint, mpint*);
 int	mptoi(mpint*);			/* int */
@@ -71,6 +76,11 @@
 void	mpexp(mpint *b, mpint *e, mpint *m, mpint *res);	/* res = b**e mod m */
 void	mpmod(mpint *b, mpint *m, mpint *remainder);	/* remainder = b mod m */
 
+/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */
+void	mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum);	/* sum = (b1+b2) % m */
+void	mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff);	/* diff = (b1-b2) % m */
+void	mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod);	/* prod = (b1*b2) % m */
+
 /* quotient = dividend/divisor, remainder = dividend % divisor */
 void	mpdiv(mpint *dividend, mpint *divisor,  mpint *quotient, mpint *remainder);
 
@@ -77,6 +87,9 @@
 /* return neg, 0, pos as b1-b2 is neg, 0, pos */
 int	mpcmp(mpint *b1, mpint *b2);
 
+/* res = s != 0 ? b1 : b2 */
+void	mpsel(int s, mpint *b1, mpint *b2, mpint *res);
+
 /* extended gcd return d, x, and y, s.t. d = gcd(a,b) and ax+by = d */
 void	mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y);
 
@@ -106,12 +119,14 @@
 /* prereq: p has room for n+1 digits */
 int	mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p);
 
-/* p[0:alen*blen-1] = a[0:alen-1] * b[0:blen-1] */
+/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */
 /* prereq: alen >= blen, p has room for m*n digits */
 void	mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
+void	mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
 
 /* sign of a - b or zero if the same */
 int	mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen);
+int	mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen);
 
 /* divide the 2 digit dividend by the one digit divisor and stick in quotient */
 /* we assume that the result is one digit - overflow is all 1's */
--- a/sys/man/2/mp
+++ b/sys/man/2/mp
@@ -1,6 +1,6 @@
 .TH MP 2
 .SH NAME
-mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, letomp, mptole, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpdiv, mpcmp, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
+mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, mptober, letomp, mptole, mptolel, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpmodadd, mpmodsub, mpmodmul, mpdiv, mpcmp, mpsel, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
 .SH SYNOPSIS
 .B #include <u.h>
 .br
@@ -22,7 +22,7 @@
 void	mpbits(mpint *b, int n)
 .PP
 .B
-void	mpnorm(mpint *b)
+mpint*	mpnorm(mpint *b)
 .PP
 .B
 mpint*	mpcopy(mpint *b)
@@ -52,6 +52,9 @@
 int	mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp)
 .PP
 .B
+void	mptober(mpint *b, uchar *buf, int blen)
+.PP
+.B
 mpint*	letomp(uchar *buf, uint blen, mpint *b)
 .PP
 .B
@@ -58,6 +61,9 @@
 int	mptole(mpint *b, uchar *buf, uint blen, uchar **bufp)
 .PP
 .B
+void	mptolel(mpint *b, uchar *buf, int blen)
+.PP
+.B
 uint	mptoui(mpint*)
 .PP
 .B
@@ -115,6 +121,15 @@
 	mpint *remainder)
 .PP
 .B
+void	mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
+.PP
+.B
+void	mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
+.PP
+.B
+void	mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
+.PP
+.B
 int	mpcmp(mpint *b1, mpint *b2)
 .PP
 .B
@@ -121,6 +136,9 @@
 int	mpmagcmp(mpint *b1, mpint *b2)
 .PP
 .B
+void	mpsel(int s, mpint *b1, mpint *b2, mpint *res)
+.PP
+.B
 void	mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x,
 .br
 .B
@@ -383,6 +401,24 @@
 Sign is ignored in these conversions, i.e., the byte
 array version is always positive.
 .PP
+.I Mptober
+and
+.I mptolel
+fill
+.I blen
+lower bytes of an
+.I mpint
+into a fixed length byte array.
+.I Mptober
+fills the bytes right adjusted in big endian order so that the least
+significant byte is at
+.I buf[blen-1]
+while
+.I mptolel
+fills in little endian order, left adjusted, so that the least
+significant byte is filled into
+.IR buf[0] .
+.PP
 .IR Betomp ,
 and
 .I letomp
@@ -486,8 +522,33 @@
 the same as
 .I mpcmp
 but ignores the sign and just compares magnitudes.
+.TP
+.I mpsel
+assigns
+.I b1
+to
+.I res
+when
+.I s
+is not zero, otherwise
+.I b2
+is assigned to
+.IR res .
 .PD
 .PP
+Modular arithmetic:
+.TF mpmodmul_
+.TP
+.I mpmodadd	
+.BR "sum = b1+b2 mod m" .
+.TP
+.I mpmodsub	
+.BR "diff = b1-b2 mod m" .
+.TP
+.I mpmodmul	
+.BR "prod = b1*b2 mod m" .
+.PD
+.PP
 .I Mpextendedgcd
 computes the greatest common denominator,
 .IR d ,
@@ -564,8 +625,8 @@
 -1 if negative.
 .TP
 .I mpvecmul
-.BR "p[0:alen*blen] = a[0:alen-1] * b[0:blen-1]" .
-We assume that p has room for alen*blen+1 digits.
+.BR "p[0:alen+blen] = a[0:alen-1] * b[0:blen-1]" .
+We assume that p has room for alen+blen+1 digits.
 .TP
 .I mpveccmp
 This returns -1, 0, or +1 as a - b is negative, 0, or positive.
@@ -576,6 +637,17 @@
 and
 .I mpzero
 are the constants 2, 1 and 0.  These cannot be freed.
+.SS "Time invariant computation"
+.PP
+In the field of cryptography, it is sometimes necessary to implement
+algorithms such that the runtime of the algorithm is not dependent on
+the input data. This library provides partial support for time
+invariant computation with the
+.I MPtimesafe
+flag that can be set on input or destination operands to request timing
+safe operation. The result of a timing safe operation will also have the
+.I MPtimesafe
+flag set and is not normalized.
 .SS "Chinese remainder theorem
 .PP
 When computing in a non-prime modulus, 
--- a/sys/src/libmp/port/betomp.c
+++ b/sys/src/libmp/port/betomp.c
@@ -13,19 +13,12 @@
 		b = mpnew(0);
 		setmalloctag(b, getcallerpc(&p));
 	}
-
-	// dump leading zeros
-	while(*p == 0 && n > 1){
-		p++;
-		n--;
-	}
-
-	// get the space
 	mpbits(b, n*8);
-	b->top = DIGITS(n*8);
-	m = b->top-1;
 
-	// first digit might not be Dbytes long
+	m = DIGITS(n*8);
+	b->top = m--;
+	b->sign = 1;
+
 	s = ((n-1)*8)%Dbits;
 	x = 0;
 	for(; n > 0; n--){
@@ -37,6 +30,5 @@
 			x = 0;
 		}
 	}
-
-	return b;
+	return mpnorm(b);
 }
--- a/sys/src/libmp/port/letomp.c
+++ b/sys/src/libmp/port/letomp.c
@@ -9,8 +9,10 @@
 	int i=0, m = 0;
 	mpdigit x=0;
 
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(0);
+		setmalloctag(b, getcallerpc(&s));
+	}
 	mpbits(b, 8*n);
 	for(; n > 0; n--){
 		x |= ((mpdigit)(*s++)) << i;
@@ -24,5 +26,6 @@
 	if(i > 0)
 		b->p[m++] = x;
 	b->top = m;
-	return b;
+	b->sign = 1;
+	return mpnorm(b);
 }
--- a/sys/src/libmp/port/mkfile
+++ b/sys/src/libmp/port/mkfile
@@ -6,12 +6,15 @@
 	mpfmt\
 	strtomp\
 	mptobe\
+	mptober\
 	mptole\
+	mptolel\
 	betomp\
 	letomp\
 	mpadd\
 	mpsub\
 	mpcmp\
+	mpsel\
 	mpfactorial\
 	mpmul\
 	mpleft\
@@ -20,10 +23,12 @@
 	mpvecsub\
 	mpvecdigmuladd\
 	mpveccmp\
+	mpvectscmp\
 	mpdigdiv\
 	mpdiv\
 	mpexp\
 	mpmod\
+	mpmodop\
 	mpextendedgcd\
 	mpinvert\
 	mprand\
--- a/sys/src/libmp/port/mpadd.c
+++ b/sys/src/libmp/port/mpadd.c
@@ -9,6 +9,8 @@
 	int m, n;
 	mpint *t;
 
+	sum->flags |= (b1->flags | b2->flags) & MPtimesafe;
+
 	// get the sizes right
 	if(b2->top > b1->top){
 		t = b1;
@@ -41,6 +43,7 @@
 	int sign;
 
 	if(b1->sign != b2->sign){
+		assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);
 		if(b1->sign < 0)
 			mpmagsub(b2, b1, sum);
 		else
--- a/sys/src/libmp/port/mpaux.c
+++ b/sys/src/libmp/port/mpaux.c
@@ -5,11 +5,9 @@
 static mpdigit _mptwodata[1] = { 2 };
 static mpint _mptwo =
 {
-	1,
-	1,
-	1,
+	1, 1, 1,
 	_mptwodata,
-	MPstatic
+	MPstatic|MPnorm
 };
 mpint *mptwo = &_mptwo;
 
@@ -16,11 +14,9 @@
 static mpdigit _mponedata[1] = { 1 };
 static mpint _mpone =
 {
-	1,
-	1,
-	1,
+	1, 1, 1,
 	_mponedata,
-	MPstatic
+	MPstatic|MPnorm
 };
 mpint *mpone = &_mpone;
 
@@ -27,11 +23,9 @@
 static mpdigit _mpzerodata[1] = { 0 };
 static mpint _mpzero =
 {
-	1,
-	1,
-	0,
+	1, 1, 0,
 	_mpzerodata,
-	MPstatic
+	MPstatic|MPnorm
 };
 mpint *mpzero = &_mpzero;
 
@@ -57,18 +51,17 @@
 	if(n < 0)
 		sysfatal("mpsetminbits: n < 0");
 
-	b = mallocz(sizeof(mpint), 1);
-	setmalloctag(b, getcallerpc(&n));
-	if(b == nil)
-		sysfatal("mpnew: %r");
 	n = DIGITS(n);
 	if(n < mpmindigits)
 		n = mpmindigits;
-	b->p = (mpdigit*)mallocz(n*Dbytes, 1);
-	if(b->p == nil)
+	b = mallocz(sizeof(mpint) + n*Dbytes, 1);
+	if(b == nil)
 		sysfatal("mpnew: %r");
+	setmalloctag(b, getcallerpc(&n));
+	b->p = (mpdigit*)&b[1];
 	b->size = n;
 	b->sign = 1;
+	b->flags = MPnorm;
 
 	return b;
 }
@@ -83,16 +76,23 @@
 	if(b->size >= n){
 		if(b->top >= n)
 			return;
-		memset(&b->p[b->top], 0, Dbytes*(n - b->top));
-		b->top = n;
-		return;
+	} else {
+		if(b->p == (mpdigit*)&b[1]){
+			b->p = (mpdigit*)mallocz(n*Dbytes, 0);
+			if(b->p == nil)
+				sysfatal("mpbits: %r");
+			memmove(b->p, &b[1], Dbytes*b->top);
+			memset(&b[1], 0, Dbytes*b->size);
+		} else {
+			b->p = (mpdigit*)realloc(b->p, n*Dbytes);
+			if(b->p == nil)
+				sysfatal("mpbits: %r");
+		}
+		b->size = n;
 	}
-	b->p = (mpdigit*)realloc(b->p, n*Dbytes);
-	if(b->p == nil)
-		sysfatal("mpbits: %r");
 	memset(&b->p[b->top], 0, Dbytes*(n - b->top));
-	b->size = n;
 	b->top = n;
+	b->flags &= ~MPnorm;
 }
 
 void
@@ -102,16 +102,22 @@
 		return;
 	if(b->flags & MPstatic)
 		sysfatal("freeing mp constant");
-	memset(b->p, 0, b->size*Dbytes);	// information hiding
-	free(b->p);
+	memset(b->p, 0, b->size*Dbytes);
+	if(b->p != (mpdigit*)&b[1])
+		free(b->p);
 	free(b);
 }
 
-void
+mpint*
 mpnorm(mpint *b)
 {
 	int i;
 
+	if(b->flags & MPtimesafe){
+		assert(b->sign == 1);
+		b->flags &= ~MPnorm;
+		return b;
+	}
 	for(i = b->top-1; i >= 0; i--)
 		if(b->p[i] != 0)
 			break;
@@ -118,6 +124,8 @@
 	b->top = i+1;
 	if(b->top == 0)
 		b->sign = 1;
+	b->flags |= MPnorm;
+	return b;
 }
 
 mpint*
@@ -126,8 +134,10 @@
 	mpint *new;
 
 	new = mpnew(Dbits*old->size);
-	new->top = old->top;
+	setmalloctag(new, getcallerpc(&old));
 	new->sign = old->sign;
+	new->top = old->top;
+	new->flags = old->flags & ~MPstatic;
 	memmove(new->p, old->p, Dbytes*old->top);
 	return new;
 }
@@ -135,9 +145,14 @@
 void
 mpassign(mpint *old, mpint *new)
 {
+	if(new == nil || old == new)
+		return;
+	new->top = 0;
 	mpbits(new, Dbits*old->top);
 	new->sign = old->sign;
 	new->top = old->top;
+	new->flags &= ~MPnorm;
+	new->flags |= old->flags & ~MPstatic;
 	memmove(new->p, old->p, Dbytes*old->top);
 }
 
@@ -167,6 +182,7 @@
 	int k, bit, digit;
 	mpdigit d;
 
+	assert(n->flags & MPnorm);
 	if(n->top==0)
 		return 0;
 	k = 0;
@@ -187,4 +203,3 @@
 	}
 	return k;
 }
-
--- a/sys/src/libmp/port/mpcmp.c
+++ b/sys/src/libmp/port/mpcmp.c
@@ -8,10 +8,14 @@
 {
 	int i;
 
-	i = b1->top - b2->top;
-	if(i)
-		return i;
-
+	i = b1->flags | b2->flags;
+	if(i & MPtimesafe)
+		return mpvectscmp(b1->p, b1->top, b2->p, b2->top);
+	if(i & MPnorm){
+		i = b1->top - b2->top;
+		if(i)
+			return i;
+	}
 	return mpveccmp(b1->p, b1->top, b2->p, b2->top);
 }
 
@@ -19,10 +23,8 @@
 int
 mpcmp(mpint *b1, mpint *b2)
 {
-	if(b1->sign != b2->sign)
-		return b1->sign - b2->sign;
-	if(b1->sign < 0)
-		return mpmagcmp(b2, b1);
-	else
-		return mpmagcmp(b1, b2);
+	int sign;
+
+	sign = (b1->sign - b2->sign) >> 1;	// -1, 0, 1
+	return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign;
 }
--- a/sys/src/libmp/port/mpdiv.c
+++ b/sys/src/libmp/port/mpdiv.c
@@ -13,10 +13,29 @@
 	mpdigit qd, *up, *vp, *qp;
 	mpint *u, *v, *t;
 
+	assert(quotient != remainder);
+	assert(divisor->flags & MPnorm);
+
 	// divide bv zero
 	if(divisor->top == 0)
 		abort();
 
+	// division by one or small powers of two
+	if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){
+		vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1);
+		if(quotient != nil){
+			for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++)
+				;
+			mpright(dividend, s, quotient);
+		}
+		if(remainder != nil){
+			remainder->flags |= dividend->flags & MPtimesafe;
+			vtomp(r, remainder);
+		}
+		return;
+	}
+	assert((dividend->flags & MPtimesafe) == 0);
+
 	// quick check
 	if(mpmagcmp(dividend, divisor) < 0){
 		if(remainder != nil)
@@ -95,6 +114,7 @@
 		*up-- = 0;
 	}
 	if(qp != nil){
+		assert((quotient->flags & MPtimesafe) == 0);
 		mpnorm(quotient);
 		if(dividend->sign != divisor->sign)
 			quotient->sign = -1;
@@ -101,6 +121,7 @@
 	}
 
 	if(remainder != nil){
+		assert((remainder->flags & MPtimesafe) == 0);
 		mpright(u, s, remainder);	// u is the remainder shifted
 		remainder->sign = dividend->sign;
 	}
--- a/sys/src/libmp/port/mpeuclid.c
+++ b/sys/src/libmp/port/mpeuclid.c
@@ -13,6 +13,9 @@
 {
 	mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r;
 
+	assert((a->flags&b->flags) & MPnorm);
+	assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0);
+
 	if(a->sign<0 || b->sign<0)
 		sysfatal("mpeuclid: negative arg");
 
--- a/sys/src/libmp/port/mpexp.c
+++ b/sys/src/libmp/port/mpexp.c
@@ -22,6 +22,10 @@
 	mpdigit d, bit;
 	int i, j;
 
+	assert(m->flags & MPnorm);
+	assert((e->flags & MPtimesafe) == 0);
+	res->flags |= b->flags & MPtimesafe;
+
 	i = mpcmp(e,mpzero);
 	if(i==0){
 		mpassign(mpone, res);
--- a/sys/src/libmp/port/mpextendedgcd.c
+++ b/sys/src/libmp/port/mpextendedgcd.c
@@ -5,7 +5,7 @@
 
 // extended binary gcd
 //
-// For a anv b it solves, v = gcd(a,b) and finds x and y s.t.
+// For a and b it solves, v = gcd(a,b) and finds x and y s.t.
 // ax + by = v
 //
 // Handbook of Applied Cryptography, Menezes et al, 1997, pg 608.  
@@ -14,6 +14,9 @@
 {
 	mpint *u, *A, *B, *C, *D;
 	int g;
+
+	assert((a->flags&b->flags) & MPnorm);
+	assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0);
 
 	if(a->sign < 0 || b->sign < 0){
 		mpassign(mpzero, v);
--- a/sys/src/libmp/port/mpfmt.c
+++ b/sys/src/libmp/port/mpfmt.c
@@ -102,6 +102,7 @@
 		return -1;
 
 	d = mpcopy(b);
+	mpnorm(d);
 	r = mpnew(0);
 	billion = uitomp(1000000000, nil);
 	out = buf+len;
@@ -128,14 +129,19 @@
 mpfmt(Fmt *fmt)
 {
 	mpint *b;
-	char *p;
+	char *p, f;
 
 	b = va_arg(fmt->args, mpint*);
 	if(b == nil)
 		return fmtstrcpy(fmt, "*");
-	
+
+	f = b->flags;
+	b->flags &= ~MPtimesafe;
+
 	p = mptoa(b, fmt->prec, nil, 0);
 	fmt->flags &= ~FmtPrec;
+
+	b->flags = f;
 
 	if(p == nil)
 		return fmtstrcpy(fmt, "*");
--- a/sys/src/libmp/port/mpleft.c
+++ b/sys/src/libmp/port/mpleft.c
@@ -15,8 +15,8 @@
 		return;
 	}
 
-	// a negative left shift is a right shift
-	if(shift < 0){
+	// a zero or negative left shift is a right shift
+	if(shift <= 0){
 		mpright(b, -shift, res);
 		return;
 	}
@@ -46,7 +46,6 @@
 	for(i = 0; i < d; i++)
 		res->p[i] = 0;
 
-	// normalize
-	while(res->top > 0 && res->p[res->top-1] == 0)
-		res->top--;
+	res->flags |= b->flags & MPtimesafe;
+	mpnorm(res);
 }
--- a/sys/src/libmp/port/mpmod.c
+++ b/sys/src/libmp/port/mpmod.c
@@ -2,14 +2,100 @@
 #include <mp.h>
 #include "dat.h"
 
-// remainder = b mod m
-//
-// knuth, vol 2, pp 398-400
-
 void
-mpmod(mpint *b, mpint *m, mpint *remainder)
+mpmod(mpint *x, mpint *n, mpint *r)
 {
-	mpdiv(b, m, nil, remainder);
-	if(remainder->sign < 0)
-		mpadd(m, remainder, remainder);
+	static int busy;
+	static mpint *p, *m, *c, *v;
+	mpdigit q[32], t[64], d;
+	int sign, k, s, qn, tn;
+
+	sign = x->sign;
+
+	assert(n->flags & MPnorm);
+	if(n->top < 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q))
+		goto hard;
+
+	/*
+	 * check if n = 2**k - c where c has few power of two factors
+	 * above the lowest digit.
+	 */
+	for(k = n->top-1; k > 0; k--){
+		d = n->p[k] >> 1;
+		if((d+1 & d) != 0)
+			goto hard;
+	}
+
+	d = n->p[n->top-1];
+	for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
+		d <<= 1;
+
+	/* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
+	k = n->top;
+
+	while(_tas(&busy))
+		;
+
+	if(p == nil || mpmagcmp(n, p) != 0){
+		if(m == nil){
+			m = mpnew(0);
+			c = mpnew(0);
+			p = mpnew(0);
+		}
+		mpassign(n, p);
+
+		mpleft(n, s, m);
+		mpleft(mpone, k*Dbits, c);
+		mpsub(c, m, c);
+	}
+
+	mpleft(x, s, r);
+	if(r->top <= k){
+		mpbits(r, (k+1)*Dbits);
+		r->top = k+1;
+	}
+
+	/* q = hi(r) */
+	qn = r->top - k;
+	memmove(q, r->p+k, qn*Dbytes);
+
+	/* r = lo(r) */
+	r->top = k;
+
+	do {
+		/* t = q*c */
+		tn = qn + c->top;
+		memset(t, 0, tn*Dbytes);
+		mpvecmul(q, qn, c->p, c->top, t);
+
+		/* q = hi(t) */
+		qn = tn - k;
+		if(qn <= 0) qn = 0;
+		else memmove(q, t+k, qn*Dbytes);
+
+		/* r += lo(t) */
+		if(tn > k)
+			tn = k;
+		mpvecadd(r->p, k, t, tn, r->p);
+
+		/* if(r >= m) r -= m */
+		mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
+		for(tn = 0; tn < k; tn++)
+			r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
+	} while(qn > 0);
+
+	busy = 0;
+
+	if(s != 0)
+		mpright(r, s, r);
+	else
+		mpnorm(r);
+	goto done;
+
+hard:
+	mpdiv(x, n, nil, r);
+
+done:
+	if(sign < 0)
+		mpmagsub(n, r, r);
 }
--- /dev/null
+++ b/sys/src/libmp/port/mpmodop.c
@@ -1,0 +1,96 @@
+#include <u.h>
+#include <libc.h>
+#include <mp.h>
+
+/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1; returns a itself when it already qualifies, otherwise a reduced copy that the caller must mpfree */
+static mpint*
+modarg(mpint *a, mpint *m)
+{
+	if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){	/* too little room, negative, or not below m */
+		a = mpcopy(a);	/* work on a private copy from here on */
+		mpmod(a, m, a);
+		mpbits(a, Dbits*(m->top+1));	/* make sure the extra digit exists */
+		a->top = m->top;
+	} else if(a->top < m->top){
+		memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes);	/* zero the digits from a->top up to m->top so m->top digits can be read */
+	}
+	return a;
+}
+
+void
+mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)	/* sum = (b1+b2) mod m; the final reduction is a masked select rather than a branch */
+{
+	mpint *a, *b;
+	mpdigit d;
+	int i, j;
+
+	a = modarg(b1, m);	/* either the operand itself or a reduced copy */
+	b = modarg(b2, m);
+
+	sum->flags |= (a->flags | b->flags) & MPtimesafe;	/* result inherits the time-safe request */
+	mpbits(sum, Dbits*2*(m->top+1));	/* two areas of m->top+1 digits each */
+
+	mpvecadd(a->p, m->top, b->p, m->top, sum->p);	/* lower area: a+b */
+	mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1);	/* upper area: (a+b)-m */
+
+	d = sum->p[2*m->top+1];	/* top digit of (a+b)-m used as mask; presumably all ones after a borrow, else zero (depends on mpvecsub, not visible here) */
+	for(i = 0, j = m->top+1; i < m->top; i++, j++)
+		sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d);	/* keep a+b where d is set, otherwise take (a+b)-m */
+
+	sum->top = m->top;
+	sum->sign = 1;
+	mpnorm(sum);
+
+	if(a != b1)	/* modarg handed back a copy */
+		mpfree(a);
+	if(b != b2)
+		mpfree(b);
+}
+
+void
+mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)	/* diff = (b1-b2) mod m; the final correction is a masked select rather than a branch */
+{
+	mpint *a, *b;
+	mpdigit d;
+	int i, j;
+
+	a = modarg(b1, m);	/* either the operand itself or a reduced copy */
+	b = modarg(b2, m);
+
+	diff->flags |= (a->flags | b->flags) & MPtimesafe;	/* result inherits the time-safe request */
+	mpbits(diff, Dbits*2*(m->top+1));	/* two areas of m->top+1 digits each */
+
+	a->p[m->top] = 0;	/* clear the extra digit of a that mpvecsub reads below; NOTE(review): when modarg returned b1 itself this writes into the caller's operand storage */
+	mpvecsub(a->p, m->top+1, b->p, m->top, diff->p);	/* lower area: a-b over m->top+1 digits */
+	mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1);	/* upper area: (a-b)+m */
+
+	d = ~diff->p[m->top];	/* complement of the extra digit of a-b; presumably zero after a borrow, all ones otherwise (depends on mpvecsub, not visible here) */
+	for(i = 0, j = m->top+1; i < m->top; i++, j++)
+		diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d);	/* keep a-b where d is set, otherwise take (a-b)+m */
+
+	diff->top = m->top;
+	diff->sign = 1;
+	mpnorm(diff);
+
+	if(a != b1)	/* modarg handed back a copy */
+		mpfree(a);
+	if(b != b2)
+		mpfree(b);
+}
+
+void
+mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)	/* prod = (b1*b2) mod m */
+{
+	mpint *a, *b;
+
+	a = modarg(b1, m);	/* either the operand itself or a reduced copy */
+	b = modarg(b2, m);
+
+	mpmul(a, b, prod);
+	mpmod(prod, m, prod);	/* reduce the full product in place */
+
+	if(a != b1)	/* modarg handed back a copy */
+		mpfree(a);
+	if(b != b2)
+		mpfree(b);
+}
--- a/sys/src/libmp/port/mpmul.c
+++ b/sys/src/libmp/port/mpmul.c
@@ -113,10 +113,6 @@
 		a = b;
 		b = t;
 	}
-	if(blen == 0){
-		memset(p, 0, Dbytes*(alen+blen));
-		return;
-	}
 
 	if(alen >= KARATSUBAMIN && blen > 1){
 		// O(n^1.585)
@@ -132,24 +128,48 @@
 }
 
 void
+mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
+{
+	int i;
+	mpdigit *t;
+
+	if(alen < blen){
+		i = alen;
+		alen = blen;
+		blen = i;
+		t = a;
+		a = b;
+		b = t;
+	}
+	if(blen == 0)
+		return;
+	for(i = 0; i < blen; i++)
+		mpvecdigmuladd(a, alen, b[i], &p[i]);
+}
+
+void
 mpmul(mpint *b1, mpint *b2, mpint *prod)
 {
 	mpint *oprod;
 
-	oprod = nil;
+	oprod = prod;
 	if(prod == b1 || prod == b2){
-		oprod = prod;
 		prod = mpnew(0);
+		prod->flags = oprod->flags;
 	}
+	prod->flags |= (b1->flags | b2->flags) & MPtimesafe;
 
 	prod->top = 0;
 	mpbits(prod, (b1->top+b2->top+1)*Dbits);
-	mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
+	if(prod->flags & MPtimesafe)
+		mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p);
+	else
+		mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
 	prod->top = b1->top+b2->top+1;
 	prod->sign = b1->sign*b2->sign;
 	mpnorm(prod);
 
-	if(oprod != nil){
+	if(oprod != prod){
 		mpassign(prod, oprod);
 		mpfree(prod);
 	}
--- a/sys/src/libmp/port/mpnrand.c
+++ b/sys/src/libmp/port/mpnrand.c
@@ -16,8 +16,10 @@
 	mpleft(mpone, bits, m);
 	mpsub(m, mpone, m);
 
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(bits);
+		setmalloctag(b, getcallerpc(&n));
+	}
 
 	/* m = m - (m % n) */
 	mpmod(m, n, b);
--- a/sys/src/libmp/port/mprand.c
+++ b/sys/src/libmp/port/mprand.c
@@ -6,19 +6,20 @@
 mpint*
 mprand(int bits, void (*gen)(uchar*, int), mpint *b)
 {
-	int n, m;
 	mpdigit mask;
+	int n, m;
 	uchar *p;
 
 	n = DIGITS(bits);
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(bits);
-	else
+		setmalloctag(b, getcallerpc(&bits));
+	}else
 		mpbits(b, bits);
 
 	p = malloc(n*Dbytes);
 	if(p == nil)
-		return nil;
+		sysfatal("mprand: %r");
 	(*gen)(p, n*Dbytes);
 	betomp(p, n*Dbytes, b);
 	free(p);
@@ -25,18 +26,12 @@
 
 	// make sure we don't give too many bits
 	m = bits%Dbits;
-	n--;
-	if(m > 0){
-		mask = 1;
-		mask <<= m;
-		mask--;
-		b->p[n] &= mask;
-	}
+	if(m == 0)
+		return b;
 
-	for(; n >= 0; n--)
-		if(b->p[n] != 0)
-			break;
-	b->top = n+1;
-	b->sign = 1;
-	return b;
+	mask = 1;
+	mask <<= m;
+	mask--;
+	b->p[n-1] &= mask;
+	return mpnorm(b);
 }
--- a/sys/src/libmp/port/mpright.c
+++ b/sys/src/libmp/port/mpright.c
@@ -23,6 +23,9 @@
 
 	if(res != b)
 		mpbits(res, b->top*Dbits - shift);
+	else if(shift == 0)
+		return;
+
 	d = shift/Dbits;
 	r = shift - d*Dbits;
 	l = Dbits - r;
@@ -29,6 +32,7 @@
 
 	//  shift all the bits out == zero
 	if(d>=b->top){
+		res->sign = 1;
 		res->top = 0;
 		return;
 	}
@@ -46,9 +50,8 @@
 		}
 		res->p[i++] = last>>r;
 	}
-	while(i > 0 && res->p[i-1] == 0)
-		i--;
+
 	res->top = i;
-	if(i==0)
-		res->sign = 1;
+	res->flags |= b->flags & MPtimesafe;
+	mpnorm(res);
 }
--- /dev/null
+++ b/sys/src/libmp/port/mpsel.c
@@ -1,0 +1,42 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+// res = s != 0 ? b1 : b2; when MPtimesafe is set on any operand the choice is made by masking instead of branching on s
+void
+mpsel(int s, mpint *b1, mpint *b2, mpint *res)
+{
+	mpdigit d;
+	int n, m, i;
+
+	res->flags |= (b1->flags | b2->flags) & MPtimesafe;
+	if((res->flags & MPtimesafe) == 0){	// no time-safe request: plain assignment
+		mpassign(s ? b1 : b2, res);
+		return;
+	}
+	res->flags &= ~MPnorm;	// the result built below is not normalized
+
+	n = b1->top;
+	m = b2->top;
+	mpbits(res, Dbits*(n >= m ? n : m));	// room for the longer operand
+	res->top = n >= m ? n : m;
+
+	s = (-s^s|s)>>(sizeof(s)*8-1);	// sign bit of the expression is set iff s != 0; presumably yields -1 or 0 (relies on arithmetic right shift of a negative int)
+	res->sign = (b1->sign & s) | (b2->sign & ~s);
+
+	d = -((mpdigit)s & 1);	// digit mask: all ones when the low bit of s is set, else 0
+
+	i = 0;
+	while(i < n && i < m){	// digits present in both operands
+		res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d);
+		i++;
+	}
+	while(i < n){	// b1 is longer: b2 contributes zero
+		res->p[i] = b1->p[i] & d;
+		i++;
+	}
+	while(i < m){	// b2 is longer: b1 contributes zero
+		res->p[i] = b2->p[i] & ~d;
+		i++;
+	}
+}
--- a/sys/src/libmp/port/mpsub.c
+++ b/sys/src/libmp/port/mpsub.c
@@ -11,12 +11,15 @@
 
 	// get the sizes right
 	if(mpmagcmp(b1, b2) < 0){
+		assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
 		sign = -1;
 		t = b1;
 		b1 = b2;
 		b2 = t;
-	} else
+	} else {
+		diff->flags |= (b1->flags | b2->flags) & MPtimesafe;
 		sign = 1;
+	}
 	n = b1->top;
 	m = b2->top;
 	if(m == 0){
@@ -39,6 +42,7 @@
 	int sign;
 
 	if(b1->sign != b2->sign){
+		assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
 		sign = b1->sign;
 		mpmagadd(b1, b2, diff);
 		diff->sign = sign;
--- a/sys/src/libmp/port/mptobe.c
+++ b/sys/src/libmp/port/mptobe.c
@@ -2,57 +2,31 @@
 #include <mp.h>
 #include "dat.h"
 
-// convert an mpint into a big endian byte array (most significant byte first)
+// convert an mpint into a big endian byte array (most significant byte first; left adjusted)
 //   return number of bytes converted
 //   if p == nil, allocate and result array
 int
 mptobe(mpint *b, uchar *p, uint n, uchar **pp)
 {
-	int i, j, suppress;
-	mpdigit x;
-	uchar *e, *s, c;
+	int m;
 
+	m = (mpsignif(b)+7)/8;
+	if(m == 0)
+		m++;
 	if(p == nil){
-		n = (b->top+1)*Dbytes;
+		n = m;
 		p = malloc(n);
+		if(p == nil)
+			sysfatal("mptobe: %r");
 		setmalloctag(p, getcallerpc(&b));
+	} else {
+		if(n < m)
+			return -1;
+		if(n > m)
+			memset(p+m, 0, n-m);
 	}
-	if(p == nil)
-		return -1;
 	if(pp != nil)
 		*pp = p;
-	memset(p, 0, n);
-
-	// special case 0
-	if(b->top == 0){
-		if(n < 1)
-			return -1;
-		else
-			return 1;
-	}
-		
-	s = p;
-	e = s+n;
-	suppress = 1;
-	for(i = b->top-1; i >= 0; i--){
-		x = b->p[i];
-		for(j = Dbits-8; j >= 0; j -= 8){
-			c = x>>j;
-			if(c == 0 && suppress)
-				continue;
-			if(p >= e)
-				return -1;
-			*p++ = c;
-			suppress = 0;
-		}
-	}
-
-	// guarantee at least one byte
-	if(s == p){
-		if(p >= e)
-			return -1;
-		*p++ = 0;
-	}
-
-	return p - s;
+	mptober(b, p, m);
+	return m;
 }
--- /dev/null
+++ b/sys/src/libmp/port/mptober.c
@@ -1,0 +1,34 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+void
+mptober(mpint *b, uchar *p, int n)	// store the low n bytes of b in big endian order, right adjusted in p[0..n-1]
+{
+	int i, j, m;
+	mpdigit x;
+
+	memset(p, 0, n);	// bytes not written below stay zero
+
+	p += n;	// fill backwards from the end of the buffer
+	m = b->top*Dbytes;	// number of bytes held in b's digits
+	if(m < n)
+		n = m;
+
+	i = 0;
+	while(n >= Dbytes){	// whole digits, least significant digit first
+		n -= Dbytes;
+		x = b->p[i++];
+		for(j = 0; j < Dbytes; j++){
+			*--p = x;
+			x >>= 8;
+		}
+	}
+	if(n > 0){	// remaining low bytes of the next digit
+		x = b->p[i];
+		for(j = 0; j < n; j++){
+			*--p = x;
+			x >>= 8;
+		}
+	}
+}
--- a/sys/src/libmp/port/mptoi.c
+++ b/sys/src/libmp/port/mptoi.c
@@ -10,17 +10,15 @@
 mpint*
 itomp(int i, mpint *b)
 {
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(0);
-	mpassign(mpzero, b);
-	if(i != 0)
-		b->top = 1;
-	if(i < 0){
-		b->sign = -1;
-		*b->p = -i;
-	} else
-		*b->p = i;
-	return b;
+		setmalloctag(b, getcallerpc(&i));
+	}
+	b->sign = (i >> (sizeof(i)*8 - 1)) | 1;
+	i *= b->sign;
+	*b->p = i;
+	b->top = 1;
+	return mpnorm(b);
 }
 
 int
--- a/sys/src/libmp/port/mptole.c
+++ b/sys/src/libmp/port/mptole.c
@@ -3,52 +3,26 @@
 #include "dat.h"
 
 // convert an mpint into a little endian byte array (least significant byte first)
-
 //   return number of bytes converted
 //   if p == nil, allocate and result array
 int
 mptole(mpint *b, uchar *p, uint n, uchar **pp)
 {
-	int i, j;
-	mpdigit x;
-	uchar *e, *s;
+	int m;
 
+	m = (mpsignif(b)+7)/8;
+	if(m == 0)
+		m++;
 	if(p == nil){
-		n = (b->top+1)*Dbytes;
+		n = m;
 		p = malloc(n);
-	}
+		if(p == nil)
+			sysfatal("mptole: %r");
+		setmalloctag(p, getcallerpc(&b));
+	} else if(n < m)
+		return -1;
 	if(pp != nil)
 		*pp = p;
-	if(p == nil)
-		return -1;
-	memset(p, 0, n);
-
-	// special case 0
-	if(b->top == 0){
-		if(n < 1)
-			return -1;
-		else
-			return 0;
-	}
-		
-	s = p;
-	e = s+n;
-	for(i = 0; i < b->top-1; i++){
-		x = b->p[i];
-		for(j = 0; j < Dbytes; j++){
-			if(p >= e)
-				return -1;
-			*p++ = x;
-			x >>= 8;
-		}
-	}
-	x = b->p[i];
-	while(x > 0){
-		if(p >= e)
-			return -1;
-		*p++ = x;
-		x >>= 8;
-	}
-
-	return p - s;
+	mptolel(b, p, n);
+	return m;
 }
--- /dev/null
+++ b/sys/src/libmp/port/mptolel.c
@@ -1,0 +1,33 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+void
+mptolel(mpint *b, uchar *p, int n)	// store the low n bytes of b in little endian order, left adjusted in p[0..n-1]
+{
+	int i, j, m;
+	mpdigit x;
+
+	memset(p, 0, n);	// bytes not written below stay zero
+
+	m = b->top*Dbytes;	// number of bytes held in b's digits
+	if(m < n)
+		n = m;
+
+	i = 0;
+	while(n >= Dbytes){	// whole digits, least significant digit first
+		n -= Dbytes;
+		x = b->p[i++];
+		for(j = 0; j < Dbytes; j++){
+			*p++ = x;
+			x >>= 8;
+		}
+	}
+	if(n > 0){	// remaining low bytes of the next digit
+		x = b->p[i];
+		for(j = 0; j < n; j++){
+			*p++ = x;
+			x >>= 8;
+		}
+	}
+}
--- a/sys/src/libmp/port/mptoui.c
+++ b/sys/src/libmp/port/mptoui.c
@@ -10,13 +10,14 @@
 mpint*
 uitomp(uint i, mpint *b)
 {
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(0);
-	mpassign(mpzero, b);
-	if(i != 0)
-		b->top = 1;
+		setmalloctag(b, getcallerpc(&i));
+	}
 	*b->p = i;
-	return b;
+	b->top = 1;
+	b->sign = 1;
+	return mpnorm(b);
 }
 
 uint
--- a/sys/src/libmp/port/mptouv.c
+++ b/sys/src/libmp/port/mptouv.c
@@ -13,19 +13,18 @@
 {
 	int s;
 
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(VLDIGITS*sizeof(mpdigit));
-	else
+		setmalloctag(b, getcallerpc(&v));
+	}else
 		mpbits(b, VLDIGITS*sizeof(mpdigit));
-	mpassign(mpzero, b);
-	if(v == 0)
-		return b;
-	for(s = 0; s < VLDIGITS && v != 0; s++){
+	b->sign = 1;
+	for(s = 0; s < VLDIGITS; s++){
 		b->p[s] = v;
 		v >>= sizeof(mpdigit)*8;
 	}
 	b->top = s;
-	return b;
+	return mpnorm(b);
 }
 
 uvlong
@@ -37,7 +36,6 @@
 	if(b->top == 0)
 		return 0LL;
 
-	mpnorm(b);
 	if(b->top > VLDIGITS)
 		return MAXVLONG;
 
--- a/sys/src/libmp/port/mptov.c
+++ b/sys/src/libmp/port/mptov.c
@@ -14,24 +14,19 @@
 	int s;
 	uvlong uv;
 
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(VLDIGITS*sizeof(mpdigit));
-	else
+		setmalloctag(b, getcallerpc(&v));
+	}else
 		mpbits(b, VLDIGITS*sizeof(mpdigit));
-	mpassign(mpzero, b);
-	if(v == 0)
-		return b;
-	if(v < 0){
-		b->sign = -1;
-		uv = -v;
-	} else
-		uv = v;
-	for(s = 0; s < VLDIGITS && uv != 0; s++){
+	b->sign = (v >> (sizeof(v)*8 - 1)) | 1;
+	uv = v * b->sign;
+	for(s = 0; s < VLDIGITS; s++){
 		b->p[s] = uv;
 		uv >>= sizeof(mpdigit)*8;
 	}
 	b->top = s;
-	return b;
+	return mpnorm(b);
 }
 
 vlong
@@ -43,7 +38,6 @@
 	if(b->top == 0)
 		return 0LL;
 
-	mpnorm(b);
 	if(b->top > VLDIGITS){
 		if(b->sign > 0)
 			return (vlong)MAXVLONG;
--- /dev/null
+++ b/sys/src/libmp/port/mpvectscmp.c
@@ -1,0 +1,34 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+int
+mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen)	// sign of a - b as -1, 0 or 1, computed without branching on digit values
+{
+	mpdigit x, y, z, v;
+	int m, p;	// m: 1 once a difference has been found; p: 1 when that difference means a > b
+
+	if(alen > blen){
+		v = 0;
+		while(alen > blen)
+			v |= a[--alen];	// fold the extra high digits of a
+		m = p = (-v^v|v)>>Dbits-1;	// any nonzero extra digit: decided, a > b
+	} else if(blen > alen){
+		v = 0;
+		while(blen > alen)
+			v |= b[--blen];	// fold the extra high digits of b
+		m = (-v^v|v)>>Dbits-1;	// any nonzero extra digit: decided, a < b
+		p = 0;	// p must stay 0 while undecided; otherwise equal values where b carries extra zero digits would compare as 1
+	} else
+		m = p = 0;
+	while(alen-- > 0){	// common digits, most significant first
+		x = a[alen];
+		y = b[alen];
+		z = x - y;
+		x = ~x;
+		v = ((-z^z|z)>>Dbits-1) & ~m;	// 1 if these digits differ and nothing was decided yet
+		p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v);	// on a fresh difference record "no borrow", i.e. a's digit is the larger one
+		m |= v;
+	}
+	return (p-m) | m;	// m=0 gives p (0 when equal); m=1,p=1 gives 1; m=1,p=0 gives -1
+}
--- a/sys/src/libmp/port/strtomp.c
+++ b/sys/src/libmp/port/strtomp.c
@@ -50,7 +50,6 @@
 	int i;
 	mpdigit x;
 
-	b->top = 0;
 	for(p = a; *p; p++)
 		if(tab.t16[*(uchar*)p] == INVAL)
 			break;
@@ -157,8 +156,10 @@
 	int sign;
 	char *e;
 
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(0);
+		setmalloctag(b, getcallerpc(&a));
+	}
 
 	if(tab.inited == 0)
 		init();
@@ -196,10 +197,9 @@
 	if(e == a)
 		return nil;
 
-	mpnorm(b);
-	b->sign = sign;
 	if(pp != nil)
 		*pp = e;
 
-	return b;
+	b->sign = sign;
+	return mpnorm(b);
 }