shithub: riscv

Download patch

ref: efd3ac8a2328d1baf55c296a00807052473d549e
parent: b6f04b77e3d11699d664d0ca7d0ba991f9599acc
author: cinap_lenrek <[email protected]>
date: Wed Dec 16 16:18:20 EST 2015

libmp: add mpfield() function for fast field arithmetic

instead of testing for special field primes each time in mpmod(),
make it explicit with an mpfield() function that tests whether a
modulus N has a special form that can be reduced more efficiently
with some precalculation, and replaces N with an Mfield* when it can.
the Mfield*'s are recognized by mpmod() because they have the MPfield
flag set, and they provide a function pointer that executes the fast
reduction.

--- a/sys/include/mp.h
+++ b/sys/include/mp.h
@@ -8,7 +8,6 @@
  * mpdigit must be an atomic type.  mpdigit is defined
  * in the architecture specific u.h
  */
-
 typedef struct mpint mpint;
 
 struct mpint
@@ -25,6 +24,7 @@
 	MPstatic=	0x01,	/* static constant */
 	MPnorm=		0x02,	/* normalization status */
 	MPtimesafe=	0x04,	/* request time invariant computation */
+	MPfield=	0x08,	/* this mpint is a field modulus */
 
 	Dbytes=		sizeof(mpdigit),	/* bytes per digit */
 	Dbits=		Dbytes*8		/* bits per digit */
@@ -165,5 +165,18 @@
 void	crtprefree(CRTpre*);
 void	crtresfree(CRTres*);
 
+/* fast field arithmetic: moduli of a special form, see mpfield() */
+typedef struct Mfield	Mfield;
+
+struct Mfield
+{
+	mpint;	/* the modulus itself; first member, so an Mfield* can be used as an mpint* */
+	int	(*reduce)(Mfield*, mpint*, mpint*);	/* reduce(f, a, r): r = |a| mod f; returns 0 on success, -1 if a is too large (caller falls back to mpdiv) */
+};
+
+mpint *mpfield(mpint*);	/* returns N unchanged, or frees N and returns an Mfield* in its place */
+
+Mfield *gmfield(mpint*);	/* generalized mersenne field for N, or nil */
+Mfield *cnfield(mpint*);	/* crandall (2^n - c) field for N, or nil */
 
 #pragma	varargck	type	"B"	mpint*
--- /dev/null
+++ b/sys/src/libmp/port/cnfield.c
@@ -1,0 +1,136 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * fast reduction for Crandall numbers of the form 2^n - c:
+ * moduli N whose normalized form N<<s equals 2^(top*Dbits) - c
+ * with c fitting in a single digit.
+ */
+
+enum {
+	MAXDIG = 1024 / Dbits,	/* capacity of the on-stack digit buffers */
+};
+
+typedef struct CNfield CNfield;
+struct CNfield
+{
+	Mfield;
+
+	mpint	m[1];	/* N<<s, the modulus with its top bit set */
+
+	int	s;	/* left shift that normalizes N */
+	mpdigit	c;	/* m = 2^(top*Dbits) - c */
+};
+
+/*
+ * r = |a| mod N.  works on a<<s modulo m = N<<s: since
+ * 2^(k*Dbits) == c (mod m), the digits above k are folded
+ * into the low k digits as hi*c until nothing is left above,
+ * and the result is shifted back right by s.  returns -1,
+ * leaving r untouched, when a is too large for the on-stack
+ * buffers; 0 on success.
+ */
+static int
+cnreduce(Mfield *m, mpint *a, mpint *r)
+{
+	mpdigit q[MAXDIG-1], t[MAXDIG], d;
+	CNfield *f = (CNfield*)m;
+	int qn, tn, k;
+
+	k = f->top;
+	/*
+	 * q = hi(a<<s) needs up to a->top-k digits, plus one more
+	 * when s != 0; it must fit in q[MAXDIG-1], and t = q*c
+	 * needs one digit more than q.
+	 */
+	if((a->top - k) + (f->s != 0) >= MAXDIG)
+		return -1;
+
+	mpleft(a, f->s, r);
+	if(r->top <= k)
+		mpbits(r, (k+1)*Dbits);	/* pad so hi(r) below is non-empty */
+
+	/* q = hi(r) */
+	qn = r->top - k;
+	memmove(q, r->p+k, qn*Dbytes);
+
+	/* r = lo(r) */
+	r->top = k;
+	r->sign = 1;
+
+	do {
+		/* t = q*c */
+		tn = qn+1;
+		memset(t, 0, tn*Dbytes);
+		mpvecdigmuladd(q, qn, f->c, t);
+
+		/* q = hi(t) */
+		qn = tn - k;
+		if(qn <= 0) qn = 0;
+		else memmove(q, t+k, qn*Dbytes);
+
+		/* r += lo(t) */
+		if(tn > k)
+			tn = k;
+		mpvecadd(r->p, k, t, tn, r->p);
+
+		/* if(r >= m) r -= m, selecting with the borrow mask d */
+		mpvecsub(r->p, k+1, f->m->p, k, t);
+		d = t[k];
+		for(tn = 0; tn < k; tn++)
+			r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
+	} while(qn > 0);
+
+	if(f->s != 0)
+		mpright(r, f->s, r);
+	mpnorm(r);
+
+	return 0;
+}
+
+/*
+ * return a new Mfield for N if N has more than two digits and
+ * the form described above, nil otherwise (also when memory
+ * runs out).  N itself is not modified.
+ */
+Mfield*
+cnfield(mpint *N)
+{
+	mpint *M, *C;
+	CNfield *f;
+	mpdigit d;
+	int s;
+
+	if(N->top <= 2 || N->top >= MAXDIG)
+		return nil;
+	d = N->p[N->top-1];
+	if(d == 0)	/* not normalized: the shift count loop would not terminate */
+		return nil;
+	for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
+		d <<= 1;
+	f = nil;
+	C = mpnew(0);
+	M = mpcopy(N);
+	mpleft(N, s, M);
+	mpleft(mpone, M->top*Dbits, C);
+	mpsub(C, M, C);
+	if(C->top != 1)	/* c must fit in one digit */
+		goto out;
+	f = mallocz(sizeof(CNfield) + M->top*sizeof(mpdigit), 1);
+	if(f == nil)
+		goto out;
+	f->s = s;
+	f->c = C->p[0];
+	f->m->size = M->top;
+	f->m->p = (mpdigit*)&f[1];	/* m's digits follow the struct */
+	mpassign(M, f->m);
+	mpassign(N, f);
+	f->reduce = cnreduce;
+	f->flags |= MPfield;
+out:
+	mpfree(M);
+	mpfree(C);
+
+	return f;
+}
--- /dev/null
+++ b/sys/src/libmp/port/gmfield.c
@@ -1,0 +1,214 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * fast reduction for generalized mersenne numbers (GM)
+ * using a series of additions and subtractions.
+ */
+
+enum {
+	MAXDIG = 1024/Dbits,
+};
+
+typedef struct GMfield GMfield;
+struct GMfield
+{
+	Mfield;
+
+	mpint	m2[1];	/* N*N; gmreduce only accepts smaller inputs */
+
+	int	nadd;	/* number of rows of indx to add */
+	int	nsub;	/* number of rows of indx to subtract (after the adds) */
+	int	indx[256];	/* rows of d digit indices; 0 selects a zero digit */
+};
+
+/*
+ * r = |a| mod N for |a| < N*N.  the high d digits of a are
+ * gathered into d-digit rows (as listed in indx) that are
+ * added to or subtracted from the low d digits; nsub*N is
+ * added first so the subtractions cannot underflow.  the
+ * overflow digit is then folded back and N conditionally
+ * subtracted.  returns -1 if a is out of range, 0 otherwise.
+ */
+static int
+gmreduce(Mfield *m, mpint *a, mpint *r)
+{
+	GMfield *g = (GMfield*)m;
+	mpdigit d0, t[MAXDIG];
+	int i, j, d, *x;
+
+	if(mpmagcmp(a, g->m2) >= 0)
+		return -1;
+
+	if(a != r)
+		mpassign(a, r);
+
+	d = g->top;
+	mpbits(r, (d+1)*Dbits*2);
+	/* t[d..2d-1] = hi(r), the digits to fold in */
+	memmove(t+d, r->p+d, d*Dbytes);
+
+	/* r = lo(r), plus a zero digit to collect carries */
+	r->sign = 1;
+	r->top = d;
+	r->p[d] = 0;
+
+	/* r += nsub*N, so the subtractions below stay non-negative */
+	if(g->nsub > 0)
+		mpvecdigmuladd(g->p, d, g->nsub, r->p);
+
+	/*
+	 * each row of indx gathers one digit of t per output digit
+	 * into t[0..d-1]; t[0] is zeroed first so that index 0
+	 * yields a zero digit, then set to the gathered value.
+	 */
+	x = g->indx;
+	for(i=0; i<g->nadd; i++){
+		t[0] = 0;
+		d0 = t[*x++];
+		for(j=1; j<d; j++)
+			t[j] = t[*x++];
+		t[0] = d0;
+
+		mpvecadd(r->p, d+1, t, d, r->p);
+	}
+
+	for(i=0; i<g->nsub; i++){
+		t[0] = 0;
+		d0 = t[*x++];
+		for(j=1; j<d; j++)
+			t[j] = t[*x++];
+		t[0] = d0;
+
+		mpvecsub(r->p, d+1, t, d, r->p);
+	}
+
+	/* fold the overflow digit back in: r -= r->p[d]*N */
+	mpvecdigmulsub(g->p, d, r->p[d], r->p);
+	r->p[d] = 0;
+
+	/* if(r >= N) r -= N, selecting with the borrow mask d0 */
+	mpvecsub(r->p, d+1, g->p, d, r->p+d+1);
+	d0 = r->p[2*d+1];
+	for(j=0; j<d; j++)
+		r->p[j] = (r->p[j] & d0) | (r->p[j+d+1] & ~d0);
+
+	mpnorm(r);
+
+	return 0;
+}
+
+/*
+ * return a new Mfield for N if N is a generalized mersenne
+ * number: more than two and at most MAXDIG/2 digits, top bit
+ * set, expressible with small signed digits (see below), and
+ * with a reduction schedule that fits in indx.  nil otherwise,
+ * including when memory runs out.  N itself is not modified.
+ */
+Mfield*
+gmfield(mpint *N)
+{
+	int i,j,d, s, *C, *X, *x, *e;
+	mpint *M, *T;
+	GMfield *g;
+
+	d = N->top;
+	if(d <= 2 || d > MAXDIG/2 || (mpsignif(N) % Dbits) != 0)
+		return nil;
+	g = nil;
+	T = mpnew(0);
+	M = mpcopy(N);
+	C = malloc(sizeof(int)*(d+1));
+	X = malloc(sizeof(int)*(d*d));
+	if(C == nil || X == nil)
+		goto out;
+
+	/*
+	 * peel off N's digits from the bottom as small signed
+	 * values j, keeping C[d-i] = -j for digit i; give up on a
+	 * digit that is not within 8 bits of zero or of all-ones.
+	 */
+	for(i=0; i<=d; i++){
+		if((M->p[i]>>8) != 0 && (~M->p[i]>>8) != 0)
+			goto out;
+		j = M->p[i];
+		C[d - i] = -j;
+		itomp(j, T);
+		mpleft(T, i*Dbits, T);
+		mpsub(M, T, M);
+	}
+
+	/*
+	 * row i of X: 2^((d+i)*Dbits) is congruent mod N to
+	 * the sum over j < d of X[d*i+j] * 2^(j*Dbits).
+	 */
+	for(j=0; j<d; j++)
+		X[j] = C[d-j];
+	for(i=1; i<d; i++){
+		X[d*i] = X[d*(i-1) + d-1]*C[d];
+		for(j=1; j<d; j++)
+			X[d*i + j] = X[d*(i-1) + j-1] + X[d*(i-1) + d-1]*C[d-j];
+	}
+	g = mallocz(sizeof(GMfield) + (d+1)*sizeof(mpdigit)*2, 1);
+	if(g == nil)
+		goto out;
+
+	g->m2->p = (mpdigit*)&g[1];
+	g->m2->size = d*2+1;
+	mpmul(N, N, g->m2);
+	mpassign(N, g);
+	g->reduce = gmreduce;
+	g->flags |= MPfield;
+
+	/*
+	 * turn the positive, then the negative entries of X into
+	 * rows of indx; within a row each source digit d+i is used
+	 * at most once, at a free output digit j.  done once a row
+	 * comes out empty (s == 0); if indx runs out first, s is 1.
+	 */
+	s = 0;
+	x = g->indx;
+	e = x + nelem(g->indx) - d;
+	for(g->nadd=0; x <= e; x += d, g->nadd++){
+		s = 0;
+		for(i=0; i<d; i++){
+			for(j=0; j<d; j++){
+				if(X[d*i+j] > 0 && x[j] == 0){
+					X[d*i+j]--;
+					x[j] = d+i;
+					s = 1;
+					break;
+				}
+			}
+		}
+		if(s == 0)
+			break;
+	}
+	for(g->nsub=0; x <= e; x += d, g->nsub++){
+		s = 0;
+		for(i=0; i<d; i++){
+			for(j=0; j<d; j++){
+				if(X[d*i+j] < 0 && x[j] == 0){
+					X[d*i+j]++;
+					x[j] = d+i;
+					s = 1;
+					break;
+				}
+			}
+		}
+		if(s == 0)
+			break;
+	}
+	if(s != 0){	/* schedule did not fit in indx */
+		mpfree(g);
+		g = nil;
+	}
+out:
+	free(C);
+	free(X);
+	mpfree(M);
+	mpfree(T);
+	return g;
+}
+
--- a/sys/src/libmp/port/mkfile
+++ b/sys/src/libmp/port/mkfile
@@ -38,6 +38,9 @@
 	mptoui\
 	mptov\
 	mptouv\
+	mpfield\
+	cnfield\
+	gmfield\
 	mplogic\
 
 ALLOFILES=${FILES:%=%.$O}
--- a/sys/src/libmp/port/mpaux.c
+++ b/sys/src/libmp/port/mpaux.c
@@ -137,7 +137,7 @@
 	setmalloctag(new, getcallerpc(&old));
 	new->sign = old->sign;
 	new->top = old->top;
-	new->flags = old->flags & ~MPstatic;
+	new->flags = old->flags & ~(MPstatic|MPfield);
 	memmove(new->p, old->p, Dbytes*old->top);
 	return new;
 }
@@ -152,7 +152,7 @@
 	new->sign = old->sign;
 	new->top = old->top;
 	new->flags &= ~MPnorm;
-	new->flags |= old->flags & ~MPstatic;
+	new->flags |= old->flags & ~(MPstatic|MPfield);
 	memmove(new->p, old->p, Dbytes*old->top);
 }
 
--- a/sys/src/libmp/port/mpexp.c
+++ b/sys/src/libmp/port/mpexp.c
@@ -61,24 +61,22 @@
 	j = 0;
 	for(;;){
 		for(; bit != 0; bit >>= 1){
-			mpmul(t[j], t[j], t[j^1]);
-			if(bit & d)
-				mpmul(t[j^1], b, t[j]);
+			if(m != nil)
+				mpmodmul(t[j], t[j], m, t[j^1]);
 			else
+				mpmul(t[j], t[j], t[j^1]);
+			if(bit & d) {
+				if(m != nil)
+					mpmodmul(t[j^1], b, m, t[j]);
+				else
+					mpmul(t[j^1], b, t[j]);
+			} else
 				j ^= 1;
-			if(m != nil && t[j]->top > m->top){
-				mpmod(t[j], m, t[j^1]);
-				j ^= 1;
-			}
 		}
 		if(--i < 0)
 			break;
 		bit = mpdighi;
 		d = e->p[i];
-	}
-	if(m != nil){
-		mpmod(t[j], m, t[j^1]);
-		j ^= 1;
 	}
 	if(t[j] == res){
 		mpfree(t[j^1]);
--- /dev/null
+++ b/sys/src/libmp/port/mpfield.c
@@ -1,0 +1,26 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * try to replace the modulus N by an Mfield that mpmod()
+ * can reduce quickly.  on success N is freed and the Mfield
+ * is returned in its place; otherwise, or when N is nil,
+ * static or already a field, N is returned unchanged.
+ */
+mpint*
+mpfield(mpint *N)
+{
+	Mfield *f;
+
+	if(N == nil || (N->flags & (MPfield|MPstatic)) != 0)
+		return N;
+	f = cnfield(N);
+	if(f == nil)
+		f = gmfield(N);
+	if(f == nil)
+		return N;
+	setmalloctag(f, getcallerpc(&N));
+	mpfree(N);
+	return f;
+}
--- a/sys/src/libmp/port/mpmod.c
+++ b/sys/src/libmp/port/mpmod.c
@@ -5,101 +5,24 @@
+/*
+ * r = x mod n.  for negative x the result is n - (|x| mod n),
+ * or 0 when n divides x.  when n came from mpfield() (MPfield
+ * set), its fast reduce function is tried first; mpdiv() is
+ * used if n is not a field or reduce rejects x.
+ */
 void
 mpmod(mpint *x, mpint *n, mpint *r)
 {
-	static int busy;
-	static mpint *p, *m, *c, *v;
-	mpdigit q[32], t[64], d;
-	int sign, k, s, qn, tn;
+	int sign;
+	mpint *ns;
 
 	sign = x->sign;
-
-	assert(n->flags & MPnorm);
-	if(n->top <= 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q))
-		goto hard;
-
-	/*
-	 * check if n = 2**k - c where c has few power of two factors
-	 * above the lowest digit.
-	 */
-	for(k = n->top-1; k > 0; k--){
-		d = n->p[k] >> 1;
-		if((d+1 & d) != 0)
-			goto hard;
-	}
-
-	d = n->p[n->top-1];
-	for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
-		d <<= 1;
-
-	/* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
-	k = n->top;
-
-	while(_tas(&busy))
-		;
-
-	if(p == nil || mpmagcmp(n, p) != 0){
-		if(m == nil){
-			m = mpnew(0);
-			c = mpnew(0);
-			p = mpnew(0);
-		}
-		mpleft(n, s, m);
-		mpleft(mpone, k*Dbits, c);
-		mpsub(c, m, c);
-		if(c->top >= k){
-			mpassign(mpzero, p);
-			busy = 0;
-			goto hard;
-		}
-		mpassign(n, p);
-	}
-
-	mpleft(x, s, r);
-	if(r->top <= k){
-		mpbits(r, (k+1)*Dbits);
-		r->top = k+1;
-	}
-
-	/* q = hi(r) */
-	qn = r->top - k;
-	memmove(q, r->p+k, qn*Dbytes);
-
-	/* r = lo(r) */
-	r->top = k;
-
-	do {
-		/* t = q*c */
-		tn = qn + c->top;
-		memset(t, 0, tn*Dbytes);
-		mpvecmul(q, qn, c->p, c->top, t);
-
-		/* q = hi(t) */
-		qn = tn - k;
-		if(qn <= 0) qn = 0;
-		else memmove(q, t+k, qn*Dbytes);
-
-		/* r += lo(t) */
-		if(tn > k)
-			tn = k;
-		mpvecadd(r->p, k, t, tn, r->p);
-
-		/* if(r >= m) r -= m */
-		mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
-		for(tn = 0; tn < k; tn++)
-			r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
-	} while(qn > 0);
-
-	busy = 0;
-
-	if(s != 0)
-		mpright(r, s, r);
-	else
-		mpnorm(r);
-	goto done;
-
-hard:
-	mpdiv(x, n, nil, r);
-
-done:
-	if(sign < 0)
-		mpmagsub(n, r, r);
+	/* r may be n itself; keep a copy of the modulus for the sign fixup */
+	ns = sign < 0 && n == r ? mpcopy(n) : n;
+	if((n->flags & MPfield) == 0
+	|| ((Mfield*)n)->reduce((Mfield*)n, x, r) != 0)
+		mpdiv(x, n, nil, r);
+	/* r holds |x| mod n in magnitude; map a non-zero r of a negative x to n - r */
+	if(sign < 0 && mpmagcmp(r, mpzero) != 0)
+		mpmagsub(ns, r, r);
+	if(ns != n)
+		mpfree(ns);
 }