ref: 38e1e5272fc9c66a00d702246813135452819ffe
parent: b677ab0c5909942bf8946e9e9bd148dea7dae718
author: cinap_lenrek <[email protected]>
date: Sat Nov 21 04:39:59 EST 2015
libmp: initial attempt at constant time code, faster reductions for special primes (for ecc). Introduce the MPtimesafe flag to request time invariant computation; it disables normalization so that the number of significant digits is not leaked.
--- a/sys/include/mp.h
+++ b/sys/include/mp.h
@@ -22,7 +22,10 @@
enum
{
- MPstatic= 0x01,
+ MPstatic= 0x01, /* static constant */
+ MPnorm= 0x02, /* normalization status */
+ MPtimesafe= 0x04, /* request time invariant computation */
+
Dbytes= sizeof(mpdigit), /* bytes per digit */
Dbits= Dbytes*8 /* bits per digit */
};
@@ -32,7 +35,7 @@
mpint* mpnew(int n); /* create a new mpint with at least n bits */
void mpfree(mpint *b);
void mpbits(mpint *b, int n); /* ensure that b has at least n bits */
-void mpnorm(mpint *b); /* dump leading zeros */
+mpint* mpnorm(mpint *b); /* dump leading zeros */
mpint* mpcopy(mpint *b);
void mpassign(mpint *old, mpint *new);
@@ -47,8 +50,10 @@
char* mptoa(mpint*, int, char*, int);
mpint* letomp(uchar*, uint, mpint*); /* byte array, little-endian */
int mptole(mpint*, uchar*, uint, uchar**);
+void mptolel(mpint *b, uchar *p, int n);
mpint* betomp(uchar*, uint, mpint*); /* byte array, big-endian */
int mptobe(mpint*, uchar*, uint, uchar**);
+void mptober(mpint *b, uchar *p, int n);
uint mptoui(mpint*); /* unsigned int */
mpint* uitomp(uint, mpint*);
int mptoi(mpint*); /* int */
@@ -71,6 +76,11 @@
void mpexp(mpint *b, mpint *e, mpint *m, mpint *res); /* res = b**e mod m */
void mpmod(mpint *b, mpint *m, mpint *remainder); /* remainder = b mod m */
+/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */
+void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum); /* sum = (b1+b2) mod m */
+void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff); /* diff = (b1-b2) mod m */
+void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod); /* prod = (b1*b2) mod m */
+
/* quotient = dividend/divisor, remainder = dividend % divisor */
void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder);
@@ -77,6 +87,9 @@
/* return neg, 0, pos as b1-b2 is neg, 0, pos */
int mpcmp(mpint *b1, mpint *b2);
+/* res = s != 0 ? b1 : b2 */
+void mpsel(int s, mpint *b1, mpint *b2, mpint *res);
+
/* extended gcd return d, x, and y, s.t. d = gcd(a,b) and ax+by = d */
void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y);
@@ -106,12 +119,14 @@
/* prereq: p has room for n+1 digits */
int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p);
-/* p[0:alen*blen-1] = a[0:alen-1] * b[0:blen-1] */
+/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */
/* prereq: alen >= blen, p has room for m*n digits */
void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
+void mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
/* sign of a - b or zero if the same */
int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen);
+int mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen);
/* divide the 2 digit dividend by the one digit divisor and stick in quotient */
/* we assume that the result is one digit - overflow is all 1's */
--- a/sys/man/2/mp
+++ b/sys/man/2/mp
@@ -1,6 +1,6 @@
.TH MP 2
.SH NAME
-mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, letomp, mptole, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpdiv, mpcmp, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
+mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, mptober, letomp, mptole, mptolel, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpmodadd, mpmodsub, mpmodmul, mpdiv, mpcmp, mpsel, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
.SH SYNOPSIS
.B #include <u.h>
.br
@@ -22,7 +22,7 @@
void mpbits(mpint *b, int n)
.PP
.B
-void mpnorm(mpint *b)
+mpint* mpnorm(mpint *b)
.PP
.B
mpint* mpcopy(mpint *b)
@@ -52,6 +52,9 @@
int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp)
.PP
.B
+void mptober(mpint *b, uchar *buf, int blen)
+.PP
+.B
mpint* letomp(uchar *buf, uint blen, mpint *b)
.PP
.B
@@ -58,6 +61,9 @@
int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp)
.PP
.B
+void mptolel(mpint *b, uchar *buf, int blen)
+.PP
+.B
uint mptoui(mpint*)
.PP
.B
@@ -115,6 +121,15 @@
mpint *remainder)
.PP
.B
+void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
+.PP
+.B
+void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
+.PP
+.B
+void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
+.PP
+.B
int mpcmp(mpint *b1, mpint *b2)
.PP
.B
@@ -121,6 +136,9 @@
int mpmagcmp(mpint *b1, mpint *b2)
.PP
.B
+void mpsel(int s, mpint *b1, mpint *b2, mpint *res)
+.PP
+.B
void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x,
.br
.B
@@ -383,6 +401,24 @@
Sign is ignored in these conversions, i.e., the byte
array version is always positive.
.PP
+.I Mptober
+and
+.I mptolel
+store the
+.I blen
+least significant bytes of the magnitude of an
+.I mpint
+into a fixed length byte array, padding with zeros when the number is shorter.
+.I Mptober
+fills the bytes right adjusted in big endian order, so that the least
+significant byte is at
+.IR buf[blen-1] ,
+while
+.I mptolel
+fills them left adjusted in little endian order, so that the least
+significant byte is at
+.IR buf[0] .
+.PP
.IR Betomp ,
and
.I letomp
@@ -486,8 +522,33 @@
the same as
.I mpcmp
but ignores the sign and just compares magnitudes.
+.TP
+.I mpsel
+assigns
+.I b1
+to
+.I res
+when
+.I s
+is not zero, otherwise
+.I b2
+is assigned to
+.IR res .
.PD
.PP
+Modular arithmetic:
+.TF mpmodmul_
+.TP
+.I mpmodadd
+.BR "sum = b1+b2 mod m" .
+.TP
+.I mpmodsub
+.BR "diff = b1-b2 mod m" .
+.TP
+.I mpmodmul
+.BR "prod = b1*b2 mod m" .
+.PD
+.PP
.I Mpextendedgcd
computes the greatest common denominator,
.IR d ,
@@ -564,8 +625,8 @@
-1 if negative.
.TP
.I mpvecmul
-.BR "p[0:alen*blen] = a[0:alen-1] * b[0:blen-1]" .
-We assume that p has room for alen*blen+1 digits.
+.BR "p[0:alen+blen] = a[0:alen-1] * b[0:blen-1]" .
+We assume that p has room for alen+blen+1 digits.
.TP
.I mpveccmp
This returns -1, 0, or +1 as a - b is negative, 0, or positive.
@@ -576,6 +637,17 @@
and
.I mpzero
are the constants 2, 1 and 0. These cannot be freed.
+.SS "Time invariant computation"
+.PP
+In cryptography it is sometimes necessary to implement
+algorithms such that their running time does not depend on
+the input data. This library provides partial support for time
+invariant computation with the
+.I MPtimesafe
+flag that can be set on input or destination operands to request timing
+safe operation. The result of a timing safe operation will also have the
+.I MPtimesafe
+flag set and is not normalized.
.SS "Chinese remainder theorem
.PP
When computing in a non-prime modulus,
--- a/sys/src/libmp/port/betomp.c
+++ b/sys/src/libmp/port/betomp.c
@@ -13,19 +13,12 @@
b = mpnew(0);
setmalloctag(b, getcallerpc(&p));
}
-
- // dump leading zeros
- while(*p == 0 && n > 1){
- p++;
- n--;
- }
-
- // get the space
mpbits(b, n*8);
- b->top = DIGITS(n*8);
- m = b->top-1;
- // first digit might not be Dbytes long
+ m = DIGITS(n*8);
+ b->top = m--;
+ b->sign = 1;
+
s = ((n-1)*8)%Dbits;
x = 0;
for(; n > 0; n--){
@@ -37,6 +30,5 @@
x = 0;
}
}
-
- return b;
+ return mpnorm(b);
}
--- a/sys/src/libmp/port/letomp.c
+++ b/sys/src/libmp/port/letomp.c
@@ -9,8 +9,10 @@
int i=0, m = 0;
mpdigit x=0;
- if(b == nil)
+ if(b == nil){
b = mpnew(0);
+ setmalloctag(b, getcallerpc(&s));
+ }
mpbits(b, 8*n);
for(; n > 0; n--){
x |= ((mpdigit)(*s++)) << i;
@@ -24,5 +26,6 @@
if(i > 0)
b->p[m++] = x;
b->top = m;
- return b;
+ b->sign = 1;
+ return mpnorm(b);
}
--- a/sys/src/libmp/port/mkfile
+++ b/sys/src/libmp/port/mkfile
@@ -6,12 +6,15 @@
mpfmt\
strtomp\
mptobe\
+ mptober\
mptole\
+ mptolel\
betomp\
letomp\
mpadd\
mpsub\
mpcmp\
+ mpsel\
mpfactorial\
mpmul\
mpleft\
@@ -20,10 +23,12 @@
mpvecsub\
mpvecdigmuladd\
mpveccmp\
+ mpvectscmp\
mpdigdiv\
mpdiv\
mpexp\
mpmod\
+ mpmodop\
mpextendedgcd\
mpinvert\
mprand\
--- a/sys/src/libmp/port/mpadd.c
+++ b/sys/src/libmp/port/mpadd.c
@@ -9,6 +9,8 @@
int m, n;
mpint *t;
+ sum->flags |= (b1->flags | b2->flags) & MPtimesafe;
+
// get the sizes right
if(b2->top > b1->top){
t = b1;
@@ -41,6 +43,7 @@
int sign;
if(b1->sign != b2->sign){
+ assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);
if(b1->sign < 0)
mpmagsub(b2, b1, sum);
else
--- a/sys/src/libmp/port/mpaux.c
+++ b/sys/src/libmp/port/mpaux.c
@@ -5,11 +5,9 @@
static mpdigit _mptwodata[1] = { 2 };
static mpint _mptwo =
{
- 1,
- 1,
- 1,
+ 1, 1, 1,
_mptwodata,
- MPstatic
+ MPstatic|MPnorm
};
mpint *mptwo = &_mptwo;
@@ -16,11 +14,9 @@
static mpdigit _mponedata[1] = { 1 };
static mpint _mpone =
{
- 1,
- 1,
- 1,
+ 1, 1, 1,
_mponedata,
- MPstatic
+ MPstatic|MPnorm
};
mpint *mpone = &_mpone;
@@ -27,11 +23,9 @@
static mpdigit _mpzerodata[1] = { 0 };
static mpint _mpzero =
{
- 1,
- 1,
- 0,
+ 1, 1, 0,
_mpzerodata,
- MPstatic
+ MPstatic|MPnorm
};
mpint *mpzero = &_mpzero;
@@ -57,18 +51,17 @@
if(n < 0)
sysfatal("mpsetminbits: n < 0");
- b = mallocz(sizeof(mpint), 1);
- setmalloctag(b, getcallerpc(&n));
- if(b == nil)
- sysfatal("mpnew: %r");
n = DIGITS(n);
if(n < mpmindigits)
n = mpmindigits;
- b->p = (mpdigit*)mallocz(n*Dbytes, 1);
- if(b->p == nil)
+ b = mallocz(sizeof(mpint) + n*Dbytes, 1);
+ if(b == nil)
sysfatal("mpnew: %r");
+ setmalloctag(b, getcallerpc(&n));
+ b->p = (mpdigit*)&b[1];
b->size = n;
b->sign = 1;
+ b->flags = MPnorm;
return b;
}
@@ -83,16 +76,23 @@
if(b->size >= n){
if(b->top >= n)
return;
- memset(&b->p[b->top], 0, Dbytes*(n - b->top));
- b->top = n;
- return;
+ } else {
+ if(b->p == (mpdigit*)&b[1]){
+ b->p = (mpdigit*)mallocz(n*Dbytes, 0);
+ if(b->p == nil)
+ sysfatal("mpbits: %r");
+ memmove(b->p, &b[1], Dbytes*b->top);
+ memset(&b[1], 0, Dbytes*b->size);
+ } else {
+ b->p = (mpdigit*)realloc(b->p, n*Dbytes);
+ if(b->p == nil)
+ sysfatal("mpbits: %r");
+ }
+ b->size = n;
}
- b->p = (mpdigit*)realloc(b->p, n*Dbytes);
- if(b->p == nil)
- sysfatal("mpbits: %r");
memset(&b->p[b->top], 0, Dbytes*(n - b->top));
- b->size = n;
b->top = n;
+ b->flags &= ~MPnorm;
}
void
@@ -102,16 +102,22 @@
return;
if(b->flags & MPstatic)
sysfatal("freeing mp constant");
- memset(b->p, 0, b->size*Dbytes); // information hiding
- free(b->p);
+ memset(b->p, 0, b->size*Dbytes);
+ if(b->p != (mpdigit*)&b[1])
+ free(b->p);
free(b);
}
-void
+mpint*
mpnorm(mpint *b)
{
int i;
+ if(b->flags & MPtimesafe){
+ assert(b->sign == 1);
+ b->flags &= ~MPnorm;
+ return b;
+ }
for(i = b->top-1; i >= 0; i--)
if(b->p[i] != 0)
break;
@@ -118,6 +124,8 @@
b->top = i+1;
if(b->top == 0)
b->sign = 1;
+ b->flags |= MPnorm;
+ return b;
}
mpint*
@@ -126,8 +134,10 @@
mpint *new;
new = mpnew(Dbits*old->size);
- new->top = old->top;
+ setmalloctag(new, getcallerpc(&old));
new->sign = old->sign;
+ new->top = old->top;
+ new->flags = old->flags & ~MPstatic;
memmove(new->p, old->p, Dbytes*old->top);
return new;
}
@@ -135,9 +145,14 @@
void
mpassign(mpint *old, mpint *new)
{
+ if(new == nil || old == new)
+ return;	// nothing to copy
+ new->top = 0;	// so mpbits below zero-fills from digit 0 instead of keeping new's old digits
 mpbits(new, Dbits*old->top);
 new->sign = old->sign;
 new->top = old->top;
+ new->flags &= ~MPnorm;
+ new->flags |= old->flags & ~MPstatic;	// inherit old's flags except MPstatic; an MPtimesafe already on new is kept
 memmove(new->p, old->p, Dbytes*old->top);
}
@@ -167,6 +182,7 @@
int k, bit, digit;
mpdigit d;
+ assert(n->flags & MPnorm);
if(n->top==0)
return 0;
k = 0;
@@ -187,4 +203,3 @@
}
return k;
}
-
--- a/sys/src/libmp/port/mpcmp.c
+++ b/sys/src/libmp/port/mpcmp.c
@@ -8,10 +8,14 @@
{
int i;
- i = b1->top - b2->top;
- if(i)
- return i;
-
+ i = b1->flags | b2->flags;
+ if(i & MPtimesafe)
+ return mpvectscmp(b1->p, b1->top, b2->p, b2->top);
+ if(i & MPnorm){
+ i = b1->top - b2->top;
+ if(i)
+ return i;
+ }
return mpveccmp(b1->p, b1->top, b2->p, b2->top);
}
@@ -19,10 +23,8 @@
int
mpcmp(mpint *b1, mpint *b2)
{
- if(b1->sign != b2->sign)
- return b1->sign - b2->sign;
- if(b1->sign < 0)
- return mpmagcmp(b2, b1);
- else
- return mpmagcmp(b1, b2);
+ int sign;
+
+ sign = (b1->sign - b2->sign) >> 1; // -1, 0, 1; nonzero only when the signs differ
+ return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign;	// signs differ: sign; otherwise magnitude order, negated for negatives. No branch: mpmagcmp is always evaluated
}
--- a/sys/src/libmp/port/mpdiv.c
+++ b/sys/src/libmp/port/mpdiv.c
@@ -13,10 +13,29 @@
mpdigit qd, *up, *vp, *qp;
mpint *u, *v, *t;
+ assert(quotient != remainder);
+ assert(divisor->flags & MPnorm);
+
// divide bv zero
if(divisor->top == 0)
abort();
+ // division by one or small powers of two
+ if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){
+ vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1);
+ if(quotient != nil){
+ for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++)
+ ;
+ mpright(dividend, s, quotient);
+ }
+ if(remainder != nil){
+ remainder->flags |= dividend->flags & MPtimesafe;
+ vtomp(r, remainder);
+ }
+ return;
+ }
+ assert((dividend->flags & MPtimesafe) == 0);
+
// quick check
if(mpmagcmp(dividend, divisor) < 0){
if(remainder != nil)
@@ -95,6 +114,7 @@
*up-- = 0;
}
if(qp != nil){
+ assert((quotient->flags & MPtimesafe) == 0);
mpnorm(quotient);
if(dividend->sign != divisor->sign)
quotient->sign = -1;
@@ -101,6 +121,7 @@
}
if(remainder != nil){
+ assert((remainder->flags & MPtimesafe) == 0);
mpright(u, s, remainder); // u is the remainder shifted
remainder->sign = dividend->sign;
}
--- a/sys/src/libmp/port/mpeuclid.c
+++ b/sys/src/libmp/port/mpeuclid.c
@@ -13,6 +13,9 @@
{
mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r;
+ assert((a->flags&b->flags) & MPnorm);
+ assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0);
+
if(a->sign<0 || b->sign<0)
sysfatal("mpeuclid: negative arg");
--- a/sys/src/libmp/port/mpexp.c
+++ b/sys/src/libmp/port/mpexp.c
@@ -22,6 +22,10 @@
mpdigit d, bit;
int i, j;
+ assert(m->flags & MPnorm);
+ assert((e->flags & MPtimesafe) == 0);
+ res->flags |= b->flags & MPtimesafe;
+
i = mpcmp(e,mpzero);
if(i==0){
mpassign(mpone, res);
--- a/sys/src/libmp/port/mpextendedgcd.c
+++ b/sys/src/libmp/port/mpextendedgcd.c
@@ -5,7 +5,7 @@
// extended binary gcd
//
-// For a anv b it solves, v = gcd(a,b) and finds x and y s.t.
+// For a and b it solves, v = gcd(a,b) and finds x and y s.t.
// ax + by = v
//
// Handbook of Applied Cryptography, Menezes et al, 1997, pg 608.
@@ -14,6 +14,9 @@
{
mpint *u, *A, *B, *C, *D;
int g;
+
+ assert((a->flags&b->flags) & MPnorm);
+ assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0);
if(a->sign < 0 || b->sign < 0){
mpassign(mpzero, v);
--- a/sys/src/libmp/port/mpfmt.c
+++ b/sys/src/libmp/port/mpfmt.c
@@ -102,6 +102,7 @@
return -1;
d = mpcopy(b);
+ mpnorm(d);
r = mpnew(0);
billion = uitomp(1000000000, nil);
out = buf+len;
@@ -128,14 +129,19 @@
mpfmt(Fmt *fmt)
{
mpint *b;
- char *p;
+ char *p, f;
b = va_arg(fmt->args, mpint*);
if(b == nil)
return fmtstrcpy(fmt, "*");
-
+
+ f = b->flags;
+ b->flags &= ~MPtimesafe;
+
p = mptoa(b, fmt->prec, nil, 0);
fmt->flags &= ~FmtPrec;
+
+ b->flags = f;
if(p == nil)
return fmtstrcpy(fmt, "*");
--- a/sys/src/libmp/port/mpleft.c
+++ b/sys/src/libmp/port/mpleft.c
@@ -15,8 +15,8 @@
return;
}
- // a negative left shift is a right shift
- if(shift < 0){
+ // a zero or negative left shift is a right shift
+ if(shift <= 0){
mpright(b, -shift, res);
return;
}
@@ -46,7 +46,6 @@
for(i = 0; i < d; i++)
res->p[i] = 0;
- // normalize
- while(res->top > 0 && res->p[res->top-1] == 0)
- res->top--;
+ res->flags |= b->flags & MPtimesafe;
+ mpnorm(res);
}
--- a/sys/src/libmp/port/mpmod.c
+++ b/sys/src/libmp/port/mpmod.c
@@ -2,14 +2,125 @@
 #include <mp.h>
 #include "dat.h"
 
-// remainder = b mod m
-//
-// knuth, vol 2, pp 398-400
-
+/*
+ * r = x mod n.
+ *
+ * For negative x the result is n - (|x| mod n); note that
+ * this is n rather than 0 when n divides x.
+ *
+ * Moduli passing the digit test below (such as the special
+ * primes used for ecc) are reduced without division by folding
+ * the high digits back in, multiplied by c = 2**(k*Dbits) - n<<s.
+ * All other moduli go through mpdiv.
+ */
 void
-mpmod(mpint *b, mpint *m, mpint *remainder)
+mpmod(mpint *x, mpint *n, mpint *r)
 {
-	mpdiv(b, m, nil, remainder);
-	if(remainder->sign < 0)
-		mpadd(m, remainder, remainder);
+	static int busy;
+	static mpint *p, *m, *c;
+	mpdigit q[32], t[64], d;
+	int sign, k, s, qn, tn;
+	mpint *ns;
+
+	sign = x->sign;
+
+	/* r may alias n: keep n around for the final fixup of negative x */
+	ns = (sign < 0 && n == r) ? mpcopy(n) : n;
+
+	assert(n->flags & MPnorm);
+	/*
+	 * hi(x<<s) has up to x->top - n->top + 1 digits; it must
+	 * fit into q[] and, multiplied by c, into t[].
+	 */
+	if(n->top < 2 || n->top > nelem(q) || x->top - n->top >= nelem(q))
+		goto hard;
+
+	/*
+	 * check if n = 2**k - c where c has few power of two factors
+	 * above the lowest digit.
+	 */
+	for(k = n->top-1; k > 0; k--){
+		d = n->p[k] >> 1;
+		if((d+1 & d) != 0)
+			goto hard;
+	}
+
+	/* s = number of leading zero bits in the top digit of n */
+	d = n->p[n->top-1];
+	for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
+		d <<= 1;
+
+	/* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
+	k = n->top;
+
+	/* p, m and c are a static cache shared by all callers */
+	while(_tas(&busy))
+		;
+
+	if(p == nil || mpmagcmp(n, p) != 0){
+		if(m == nil){
+			m = mpnew(0);
+			c = mpnew(0);
+			p = mpnew(0);
+		}
+		mpassign(n, p);
+
+		/* m = n<<s has its top bit set; c = 2**(k*Dbits) - m */
+		mpleft(n, s, m);
+		mpleft(mpone, k*Dbits, c);
+		mpsub(c, m, c);
+	}
+
+	mpleft(x, s, r);
+	if(r->top <= k){
+		mpbits(r, (k+1)*Dbits);
+		r->top = k+1;
+	}
+
+	/* q = hi(r) */
+	qn = r->top - k;
+	memmove(q, r->p+k, qn*Dbytes);
+
+	/* r = lo(r) */
+	r->top = k;
+
+	do {
+		/* t = q*c */
+		tn = qn + c->top;
+		memset(t, 0, tn*Dbytes);
+		mpvecmul(q, qn, c->p, c->top, t);
+
+		/* q = hi(t) */
+		qn = tn - k;
+		if(qn <= 0) qn = 0;
+		else memmove(q, t+k, qn*Dbytes);
+
+		/* r += lo(t) */
+		if(tn > k)
+			tn = k;
+		mpvecadd(r->p, k, t, tn, r->p);
+
+		/* if(r >= m) r -= m */
+		mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
+		for(tn = 0; tn < k; tn++)
+			r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
+	} while(qn > 0);
+
+	busy = 0;
+
+	if(s != 0)
+		mpright(r, s, r);
+	else
+		mpnorm(r);
+	goto done;
+
+hard:
+	mpdiv(x, n, nil, r);
+
+done:
+	if(sign < 0){
+		mpmagsub(ns, r, r);
+		if(ns != n)
+			mpfree(ns);
+	}
 }
--- /dev/null
+++ b/sys/src/libmp/port/mpmodop.c
@@ -1,0 +1,96 @@
+#include <u.h>
+#include <libc.h>
+#include <mp.h>
+
+/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1; returns a itself or a reduced copy that the caller must mpfree */
+static mpint*
+modarg(mpint *a, mpint *m)
+{
+ if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){	/* too little room or out of range: reduce a private copy */
+ a = mpcopy(a);
+ mpmod(a, m, a);
+ mpbits(a, Dbits*(m->top+1));	/* room for m->top+1 digits */
+ a->top = m->top;
+ } else if(a->top < m->top){
+ memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes);	/* zero the digits up to m->top in place; a->top is left unchanged */
+ }
+ return a;
+}
+
+void
+mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
+{
+ mpint *a, *b;
+ mpdigit d;
+ int i, j;
+
+ a = modarg(b1, m);
+ b = modarg(b2, m);
+
+ sum->flags |= (a->flags | b->flags) & MPtimesafe;
+ mpbits(sum, Dbits*2*(m->top+1));	/* room for a+b and a+b-m side by side */
+
+ mpvecadd(a->p, m->top, b->p, m->top, sum->p);	/* sum->p[0:m->top] = a+b, presumably with the carry in the top digit */
+ mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1);	/* second half = (a+b) - m */
+
+ d = sum->p[2*m->top+1];	/* given 0 ≤ a,b ≤ m-1: ~0 if a+b < m (keep a+b), 0 otherwise (take a+b-m) */
+ for(i = 0, j = m->top+1; i < m->top; i++, j++)
+ sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d);	/* select without branching on d */
+
+ sum->top = m->top;
+ sum->sign = 1;
+ mpnorm(sum);	/* trims only when sum is not MPtimesafe */
+
+ if(a != b1)
+ mpfree(a);
+ if(b != b2)
+ mpfree(b);
+}
+
+void
+mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
+{
+ mpint *a, *b;
+ mpdigit d;
+ int i, j;
+
+ a = modarg(b1, m);
+ b = modarg(b2, m);
+
+ diff->flags |= (a->flags | b->flags) & MPtimesafe;
+ mpbits(diff, Dbits*2*(m->top+1));	/* room for a-b and a-b+m side by side */
+
+ a->p[m->top] = 0;	/* extra zero digit to catch the borrow; this writes into b1 itself when modarg did not copy it */
+ mpvecsub(a->p, m->top+1, b->p, m->top, diff->p);	/* diff->p[0:m->top] = a-b; the top digit becomes ~0 on borrow */
+ mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1);	/* second half = (a-b) + m */
+
+ d = ~diff->p[m->top];	/* ~0 if a >= b (keep a-b), 0 otherwise (take a-b+m) */
+ for(i = 0, j = m->top+1; i < m->top; i++, j++)
+ diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d);	/* select without branching on d */
+
+ diff->top = m->top;
+ diff->sign = 1;
+ mpnorm(diff);	/* trims only when diff is not MPtimesafe */
+
+ if(a != b1)
+ mpfree(a);
+ if(b != b2)
+ mpfree(b);
+}
+
+void
+mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
+{
+ mpint *a, *b;
+
+ a = modarg(b1, m);
+ b = modarg(b2, m);
+
+ mpmul(a, b, prod);	/* mpmul switches to mpvectsmul when an operand has MPtimesafe */
+ mpmod(prod, m, prod);	/* NOTE(review): for MPtimesafe operands this relies on mpmod's special-modulus path; mpdiv asserts on MPtimesafe dividends except for power-of-two divisors */
+
+ if(a != b1)
+ mpfree(a);
+ if(b != b2)
+ mpfree(b);
+}
--- a/sys/src/libmp/port/mpmul.c
+++ b/sys/src/libmp/port/mpmul.c
@@ -113,10 +113,6 @@
a = b;
b = t;
}
- if(blen == 0){
- memset(p, 0, Dbytes*(alen+blen));
- return;
- }
if(alen >= KARATSUBAMIN && blen > 1){
// O(n^1.585)
@@ -132,24 +128,48 @@
}
void
+mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
+{
+ int i;
+ mpdigit *t;
+
+ if(alen < blen){	/* make a the longer operand */
+ i = alen;
+ alen = blen;
+ blen = i;
+ t = a;
+ a = b;
+ b = t;
+ }
+ if(blen == 0)
+ return;	/* p is left untouched; mpmul zero-fills it before calling */
+ for(i = 0; i < blen; i++)	/* always schoolbook; mpvecmul may switch to karatsuba instead */
+ mpvecdigmuladd(a, alen, b[i], &p[i]);	/* presumably p[i:] += a*b[i], so p must start zeroed — TODO confirm */
+}
+
+void
mpmul(mpint *b1, mpint *b2, mpint *prod)
{
mpint *oprod;
- oprod = nil;
+ oprod = prod;
if(prod == b1 || prod == b2){
- oprod = prod;
prod = mpnew(0);
+ prod->flags = oprod->flags;
}
+ prod->flags |= (b1->flags | b2->flags) & MPtimesafe;
prod->top = 0;
mpbits(prod, (b1->top+b2->top+1)*Dbits);
- mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
+ if(prod->flags & MPtimesafe)
+ mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p);
+ else
+ mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
prod->top = b1->top+b2->top+1;
prod->sign = b1->sign*b2->sign;
mpnorm(prod);
- if(oprod != nil){
+ if(oprod != prod){
mpassign(prod, oprod);
mpfree(prod);
}
--- a/sys/src/libmp/port/mpnrand.c
+++ b/sys/src/libmp/port/mpnrand.c
@@ -16,8 +16,10 @@
mpleft(mpone, bits, m);
mpsub(m, mpone, m);
- if(b == nil)
+ if(b == nil){
b = mpnew(bits);
+ setmalloctag(b, getcallerpc(&n));
+ }
/* m = m - (m % n) */
mpmod(m, n, b);
--- a/sys/src/libmp/port/mprand.c
+++ b/sys/src/libmp/port/mprand.c
@@ -6,19 +6,20 @@
mpint*
mprand(int bits, void (*gen)(uchar*, int), mpint *b)
{
- int n, m;
mpdigit mask;
+ int n, m;
uchar *p;
n = DIGITS(bits);
- if(b == nil)
+ if(b == nil){
b = mpnew(bits);
- else
+ setmalloctag(b, getcallerpc(&bits));
+ }else
mpbits(b, bits);
p = malloc(n*Dbytes);
if(p == nil)
- return nil;
+ sysfatal("mprand: %r");
(*gen)(p, n*Dbytes);
betomp(p, n*Dbytes, b);
free(p);
@@ -25,18 +26,12 @@
// make sure we don't give too many bits
m = bits%Dbits;
- n--;
- if(m > 0){
- mask = 1;
- mask <<= m;
- mask--;
- b->p[n] &= mask;
- }
+ if(m == 0)
+ return b;
- for(; n >= 0; n--)
- if(b->p[n] != 0)
- break;
- b->top = n+1;
- b->sign = 1;
- return b;
+ mask = 1;
+ mask <<= m;
+ mask--;
+ b->p[n-1] &= mask;
+ return mpnorm(b);
}
--- a/sys/src/libmp/port/mpright.c
+++ b/sys/src/libmp/port/mpright.c
@@ -23,6 +23,9 @@
if(res != b)
mpbits(res, b->top*Dbits - shift);
+ else if(shift == 0)
+ return;
+
d = shift/Dbits;
r = shift - d*Dbits;
l = Dbits - r;
@@ -29,6 +32,7 @@
// shift all the bits out == zero
if(d>=b->top){
+ res->sign = 1;
res->top = 0;
return;
}
@@ -46,9 +50,8 @@
}
res->p[i++] = last>>r;
}
- while(i > 0 && res->p[i-1] == 0)
- i--;
+
res->top = i;
- if(i==0)
- res->sign = 1;
+ res->flags |= b->flags & MPtimesafe;
+ mpnorm(res);
}
--- /dev/null
+++ b/sys/src/libmp/port/mpsel.c
@@ -1,0 +1,42 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+// res = s != 0 ? b1 : b2; with MPtimesafe on any operand the choice is made by masking, not by branching on s
+void
+mpsel(int s, mpint *b1, mpint *b2, mpint *res)
+{
+ mpdigit d;
+ int n, m, i;
+
+ res->flags |= (b1->flags | b2->flags) & MPtimesafe;	// timesafe operands make res timesafe
+ if((res->flags & MPtimesafe) == 0){	// no timing requirement: plain copy
+ mpassign(s ? b1 : b2, res);
+ return;
+ }
+ res->flags &= ~MPnorm;	// the result is not normalized
+
+ n = b1->top;
+ m = b2->top;
+ mpbits(res, Dbits*(n >= m ? n : m));
+ res->top = n >= m ? n : m;	// length depends only on the operand lengths
+
+ s = (-s^s|s)>>(sizeof(s)*8-1);	// s = s != 0 ? -1 : 0 (assumes arithmetic >> of negative ints)
+ res->sign = (b1->sign & s) | (b2->sign & ~s);
+
+ d = -((mpdigit)s & 1);	// all ones selects b1, 0 selects b2
+
+ i = 0;
+ while(i < n && i < m){	// digits present in both operands
+ res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d);
+ i++;
+ }
+ while(i < n){	// b1 longer: missing digits of b2 count as zero
+ res->p[i] = b1->p[i] & d;
+ i++;
+ }
+ while(i < m){	// b2 longer: missing digits of b1 count as zero
+ res->p[i] = b2->p[i] & ~d;
+ i++;
+ }
+}
--- a/sys/src/libmp/port/mpsub.c
+++ b/sys/src/libmp/port/mpsub.c
@@ -11,12 +11,15 @@
// get the sizes right
if(mpmagcmp(b1, b2) < 0){
+ assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
sign = -1;
t = b1;
b1 = b2;
b2 = t;
- } else
+ } else {
+ diff->flags |= (b1->flags | b2->flags) & MPtimesafe;
sign = 1;
+ }
n = b1->top;
m = b2->top;
if(m == 0){
@@ -39,6 +42,7 @@
int sign;
if(b1->sign != b2->sign){
+ assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
sign = b1->sign;
mpmagadd(b1, b2, diff);
diff->sign = sign;
--- a/sys/src/libmp/port/mptobe.c
+++ b/sys/src/libmp/port/mptobe.c
@@ -2,57 +2,31 @@
#include <mp.h>
#include "dat.h"
-// convert an mpint into a big endian byte array (most significant byte first)
+// convert an mpint into a big endian byte array (most significant byte first; left adjusted)
// return number of bytes converted
// if p == nil, allocate and result array
int
mptobe(mpint *b, uchar *p, uint n, uchar **pp)
{
- int i, j, suppress;
- mpdigit x;
- uchar *e, *s, c;
+ int m;
+ m = (mpsignif(b)+7)/8;	// bytes needed for |b|; mpsignif asserts MPnorm, so b must be normalized
+ if(m == 0)
+ m++;	// zero still takes one byte
 if(p == nil){
- n = (b->top+1)*Dbytes;
+ n = m;
 p = malloc(n);
+ if(p == nil)
+ sysfatal("mptobe: %r");
 setmalloctag(p, getcallerpc(&b));
+ } else {
+ if(n < m)
+ return -1;	// caller's buffer is too small
+ if(n > m)
+ memset(p+m, 0, n-m);	// the number is left adjusted: zero the unused tail
 }
- if(p == nil)
- return -1;
 if(pp != nil)
 *pp = p;
- memset(p, 0, n);
-
- // special case 0
- if(b->top == 0){
- if(n < 1)
- return -1;
- else
- return 1;
- }
-
- s = p;
- e = s+n;
- suppress = 1;
- for(i = b->top-1; i >= 0; i--){
- x = b->p[i];
- for(j = Dbits-8; j >= 0; j -= 8){
- c = x>>j;
- if(c == 0 && suppress)
- continue;
- if(p >= e)
- return -1;
- *p++ = c;
- suppress = 0;
- }
- }
-
- // guarantee at least one byte
- if(s == p){
- if(p >= e)
- return -1;
- *p++ = 0;
- }
-
- return p - s;
+ mptober(b, p, m);	// fill exactly p[0:m-1], big endian
+ return m;
}
--- /dev/null
+++ b/sys/src/libmp/port/mptober.c
@@ -1,0 +1,34 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+void
+mptober(mpint *b, uchar *p, int n)
+{
+ int i, j, m;
+ mpdigit x;
+
+ memset(p, 0, n);	// p[0:n-1] = low n bytes of |b|, big endian, right adjusted, zero padded
+
+ p += n;	// fill backwards from the last byte
+ m = b->top*Dbytes;
+ if(m < n)
+ n = m;	// b has fewer bytes: the leading bytes stay zero
+
+ i = 0;
+ while(n >= Dbytes){	// whole digits, least significant first
+ n -= Dbytes;
+ x = b->p[i++];
+ for(j = 0; j < Dbytes; j++){
+ *--p = x;
+ x >>= 8;
+ }
+ }
+ if(n > 0){	// low bytes of one partial digit
+ x = b->p[i];
+ for(j = 0; j < n; j++){
+ *--p = x;
+ x >>= 8;
+ }
+ }
+}
--- a/sys/src/libmp/port/mptoi.c
+++ b/sys/src/libmp/port/mptoi.c
@@ -10,17 +10,16 @@
 mpint*
 itomp(int i, mpint *b)
 {
-	if(b == nil)
+	if(b == nil){
 		b = mpnew(0);
-	mpassign(mpzero, b);
-	if(i != 0)
-		b->top = 1;
-	if(i < 0){
-		b->sign = -1;
-		*b->p = -i;
-	} else
-		*b->p = i;
-	return b;
+		setmalloctag(b, getcallerpc(&i));
+	}
+	/* sign is -1 for negative i, +1 otherwise */
+	b->sign = (i >> (sizeof(i)*8 - 1)) | 1;
+	/* negate in unsigned arithmetic: -INT_MIN would overflow an int */
+	*b->p = (uint)i * b->sign;
+	b->top = 1;
+	return mpnorm(b);
 }
int
--- a/sys/src/libmp/port/mptole.c
+++ b/sys/src/libmp/port/mptole.c
@@ -3,52 +3,26 @@
#include "dat.h"
// convert an mpint into a little endian byte array (least significant byte first)
-
// return number of bytes converted
// if p == nil, allocate and result array
int
mptole(mpint *b, uchar *p, uint n, uchar **pp)
{
- int i, j;
- mpdigit x;
- uchar *e, *s;
+ int m;
+ m = (mpsignif(b)+7)/8;	// bytes needed for |b|; mpsignif asserts MPnorm, so b must be normalized
+ if(m == 0)
+ m++;	// zero still takes one byte
 if(p == nil){
- n = (b->top+1)*Dbytes;
+ n = m;
 p = malloc(n);
- }
+ if(p == nil)
+ sysfatal("mptole: %r");
+ setmalloctag(p, getcallerpc(&b));
+ } else if(n < m)
+ return -1;	// caller's buffer is too small
 if(pp != nil)
 *pp = p;
- if(p == nil)
- return -1;
- memset(p, 0, n);
-
- // special case 0
- if(b->top == 0){
- if(n < 1)
- return -1;
- else
- return 0;
- }
-
- s = p;
- e = s+n;
- for(i = 0; i < b->top-1; i++){
- x = b->p[i];
- for(j = 0; j < Dbytes; j++){
- if(p >= e)
- return -1;
- *p++ = x;
- x >>= 8;
- }
- }
- x = b->p[i];
- while(x > 0){
- if(p >= e)
- return -1;
- *p++ = x;
- x >>= 8;
- }
-
- return p - s;
+ mptolel(b, p, n);	// fills all n bytes, zero padded past the m significant ones
+ return m;
}
--- /dev/null
+++ b/sys/src/libmp/port/mptolel.c
@@ -1,0 +1,33 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+void
+mptolel(mpint *b, uchar *p, int n)
+{
+ int i, j, m;
+ mpdigit x;
+
+ memset(p, 0, n);	// p[0:n-1] = low n bytes of |b|, little endian, left adjusted, zero padded
+
+ m = b->top*Dbytes;
+ if(m < n)
+ n = m;	// b has fewer bytes: the tail stays zero
+
+ i = 0;
+ while(n >= Dbytes){	// whole digits, least significant first
+ n -= Dbytes;
+ x = b->p[i++];
+ for(j = 0; j < Dbytes; j++){
+ *p++ = x;
+ x >>= 8;
+ }
+ }
+ if(n > 0){	// low bytes of one partial digit
+ x = b->p[i];
+ for(j = 0; j < n; j++){
+ *p++ = x;
+ x >>= 8;
+ }
+ }
+}
--- a/sys/src/libmp/port/mptoui.c
+++ b/sys/src/libmp/port/mptoui.c
@@ -10,13 +10,14 @@
mpint*
uitomp(uint i, mpint *b)
{
- if(b == nil)
+ if(b == nil){
 b = mpnew(0);
- mpassign(mpzero, b);
- if(i != 0)
- b->top = 1;
+ setmalloctag(b, getcallerpc(&i));
+ }
 *b->p = i;
- return b;
+ b->top = 1;	/* mpnorm trims this to 0 for i == 0 unless b has MPtimesafe set */
+ b->sign = 1;
+ return mpnorm(b);
}
uint
--- a/sys/src/libmp/port/mptouv.c
+++ b/sys/src/libmp/port/mptouv.c
@@ -13,19 +13,18 @@
{
int s;
- if(b == nil)
+ if(b == nil){
b = mpnew(VLDIGITS*sizeof(mpdigit));
- else
+ setmalloctag(b, getcallerpc(&v));
+ }else
mpbits(b, VLDIGITS*sizeof(mpdigit));
- mpassign(mpzero, b);
- if(v == 0)
- return b;
- for(s = 0; s < VLDIGITS && v != 0; s++){
+ b->sign = 1;
+ for(s = 0; s < VLDIGITS; s++){
b->p[s] = v;
v >>= sizeof(mpdigit)*8;
}
b->top = s;
- return b;
+ return mpnorm(b);
}
uvlong
@@ -37,7 +36,6 @@
if(b->top == 0)
return 0LL;
- mpnorm(b);
if(b->top > VLDIGITS)
return MAXVLONG;
--- a/sys/src/libmp/port/mptov.c
+++ b/sys/src/libmp/port/mptov.c
@@ -14,24 +14,19 @@
int s;
uvlong uv;
- if(b == nil)
+ if(b == nil){
b = mpnew(VLDIGITS*sizeof(mpdigit));
- else
+ setmalloctag(b, getcallerpc(&v));
+ }else
mpbits(b, VLDIGITS*sizeof(mpdigit));
- mpassign(mpzero, b);
- if(v == 0)
- return b;
- if(v < 0){
- b->sign = -1;
- uv = -v;
- } else
- uv = v;
- for(s = 0; s < VLDIGITS && uv != 0; s++){
+ b->sign = (v >> (sizeof(v)*8 - 1)) | 1;
+ uv = v * b->sign;
+ for(s = 0; s < VLDIGITS; s++){
b->p[s] = uv;
uv >>= sizeof(mpdigit)*8;
}
b->top = s;
- return b;
+ return mpnorm(b);
}
vlong
@@ -43,7 +38,6 @@
if(b->top == 0)
return 0LL;
- mpnorm(b);
if(b->top > VLDIGITS){
if(b->sign > 0)
return (vlong)MAXVLONG;
--- /dev/null
+++ b/sys/src/libmp/port/mpvectscmp.c
@@ -1,0 +1,49 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * time invariant version of mpveccmp: returns -1, 0 or +1
+ * as a - b is negative, zero or positive.  All digits are
+ * visited and no branch depends on their values; the loop
+ * bounds depend only on alen and blen.
+ */
+int
+mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen)
+{
+	mpdigit x, y, z, v;
+	int m, p;
+
+	/*
+	 * m is 1 once a != b has been established,
+	 * p is 1 if a > b; p must stay 0 while m is 0
+	 * so that equal vectors compare as 0.
+	 */
+	if(alen > blen){
+		/* nonzero excess digits in a mean a > b */
+		v = 0;
+		while(alen > blen)
+			v |= a[--alen];
+		m = p = (-v^v|v)>>Dbits-1;
+	} else if(blen > alen){
+		/* nonzero excess digits in b mean a < b */
+		v = 0;
+		while(blen > alen)
+			v |= b[--blen];
+		m = (-v^v|v)>>Dbits-1;
+		p = 0;
+	} else
+		m = p = 0;
+	while(alen-- > 0){
+		x = a[alen];
+		y = b[alen];
+		z = x - y;
+		x = ~x;
+		/* v is 1 at the most significant differing digit */
+		v = ((-z^z|z)>>Dbits-1) & ~m;
+		/* there, a[i] > b[i] iff a[i] - b[i] does not borrow */
+		p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v);
+		m |= v;
+	}
+	return (p-m) | m;
+}
--- a/sys/src/libmp/port/strtomp.c
+++ b/sys/src/libmp/port/strtomp.c
@@ -50,7 +50,6 @@
int i;
mpdigit x;
- b->top = 0;
for(p = a; *p; p++)
if(tab.t16[*(uchar*)p] == INVAL)
break;
@@ -157,8 +156,10 @@
int sign;
char *e;
- if(b == nil)
+ if(b == nil){
b = mpnew(0);
+ setmalloctag(b, getcallerpc(&a));
+ }
if(tab.inited == 0)
init();
@@ -196,10 +197,9 @@
if(e == a)
return nil;
- mpnorm(b);
- b->sign = sign;
if(pp != nil)
*pp = e;
- return b;
+ b->sign = sign;
+ return mpnorm(b);
}