ref: efd3ac8a2328d1baf55c296a00807052473d549e
parent: b6f04b77e3d11699d664d0ca7d0ba991f9599acc
author: cinap_lenrek <[email protected]>
date: Wed Dec 16 16:18:20 EST 2015
libmp: add mpfield() function for fast field arithmetic. Instead of testing for special field primes on every call to mpmod(), make it explicit with an mpfield() function that tests whether a modulus N has a special form that can be reduced more efficiently with some precalculation, and replaces N with an Mfield* when it can. The Mfield*s are recognized by mpmod() because they have the MPfield flag set, and they provide a function pointer that performs the fast reduction.
--- a/sys/include/mp.h
+++ b/sys/include/mp.h
@@ -8,7 +8,6 @@
* mpdigit must be an atomic type. mpdigit is defined
* in the architecture specific u.h
*/
-
typedef struct mpint mpint;
struct mpint
@@ -25,6 +24,7 @@
MPstatic= 0x01, /* static constant */
MPnorm= 0x02, /* normalization status */
MPtimesafe= 0x04, /* request time invariant computation */
+ MPfield= 0x08, /* this mpint is a field modulus */
Dbytes= sizeof(mpdigit), /* bytes per digit */
Dbits= Dbytes*8 /* bits per digit */
@@ -165,5 +165,18 @@
void crtprefree(CRTpre*);
void crtresfree(CRTres*);
+/* fast field arithmetic: moduli that mpmod() reduces with a specialized routine */
+typedef struct Mfield Mfield;
+
+struct Mfield
+{
+ mpint;	/* the modulus itself; MPfield is set in its flags */
+ int (*reduce)(Mfield*, mpint*, mpint*);	/* reduce(f, a, r): r = |a| mod f, returns 0; returns -1 and leaves r alone if it cannot handle a */
+};
+
+mpint *mpfield(mpint*);	/* replace (and free) N by an Mfield if it has a special form, else return N */
+
+Mfield *gmfield(mpint*);	/* generalized Mersenne moduli; nil if N does not qualify */
+Mfield *cnfield(mpint*);	/* Crandall moduli 2^n - c; nil if N does not qualify */
#pragma varargck type "B" mpint*
--- /dev/null
+++ b/sys/src/libmp/port/cnfield.c
@@ -1,0 +1,162 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * fast reduction for Crandall numbers of the form: 2^n - c
+ *
+ * N is shifted left by s bits so that m = N<<s has the top bit
+ * of its top digit set.  when m = 2^(k*Dbits) - c for a single
+ * digit c, 2^(k*Dbits) = c (mod m), so the part of a number above
+ * digit k can be folded into the low k digits by multiplying it
+ * by c.
+ */
+
+enum {
+	MAXDIG = 1024 / Dbits,
+};
+
+typedef struct CNfield CNfield;
+struct CNfield
+{
+	Mfield;
+
+	mpint	m[1];	/* m = N<<s */
+
+	int	s;	/* normalizing shift of N */
+	mpdigit	c;	/* m = 2^(k*Dbits) - c, k = digits of N */
+};
+
+/*
+ * set r = |a| mod N and return 0, or return -1 without touching r
+ * when a has too many digits for the fixed-size scratch buffers;
+ * mpmod then falls back to mpdiv.
+ */
+static int
+cnreduce(Mfield *m, mpint *a, mpint *r)
+{
+	mpdigit q[MAXDIG-1], t[MAXDIG], d, cc;
+	CNfield *f = (CNfield*)m;
+	int qn, tn, k, i;
+
+	k = f->top;
+
+	/*
+	 * hi(a<<s) has a->top-k digits, plus one more when the shift
+	 * carries out of the top digit.  it must fit in q, and q*c,
+	 * one digit longer, must fit in t.
+	 */
+	if(a->top - k + (f->s != 0) >= MAXDIG)
+		return -1;
+
+	mpleft(a, f->s, r);
+	if(r->top <= k)
+		mpbits(r, (k+1)*Dbits);
+
+	/* q = hi(r) */
+	qn = r->top - k;
+	memmove(q, r->p+k, qn*Dbytes);
+
+	/* r = lo(r) */
+	r->top = k;
+	r->sign = 1;
+
+	do {
+		/* t = q*c */
+		tn = qn+1;
+		memset(t, 0, tn*Dbytes);
+		mpvecdigmuladd(q, qn, f->c, t);
+
+		/* q = hi(t) */
+		qn = tn - k;
+		if(qn <= 0) qn = 0;
+		else memmove(q, t+k, qn*Dbytes);
+
+		/* r += lo(t); the carry out lands in r->p[k] */
+		if(tn > k)
+			tn = k;
+		mpvecadd(r->p, k, t, tn, r->p);
+
+		/*
+		 * fold the carry back in as c.  if the first fold carries
+		 * again, lo(r) is now below c and the second cannot, so
+		 * afterwards r->p[k] == 0 and r < 2^(k*Dbits) = m + c.
+		 * (a single "r -= m" here would not do: r - m can still be
+		 * >= 2^(k*Dbits), and its borrow digit is then 1, not a
+		 * 0/~0 mask.)
+		 */
+		for(i = 0; i < 2; i++){
+			cc = f->c & -r->p[k];
+			mpvecadd(r->p, k, &cc, 1, r->p);
+		}
+	} while(qn > 0);
+
+	/*
+	 * r < m + c < 2m, so one conditional subtraction of m finishes
+	 * the reduction.  r->p[k] == 0, so the borrow digit t[k] is 0
+	 * or ~0; use it as a mask to select r or r-m without branching.
+	 */
+	mpvecsub(r->p, k+1, f->m->p, k, t);
+	d = t[k];
+	for(i = 0; i < k; i++)
+		r->p[i] = (r->p[i] & d) | (t[i] & ~d);
+
+	/* (a<<s) mod (N<<s) = (a mod N)<<s, so shifting back is exact */
+	if(f->s != 0)
+		mpright(r, f->s, r);
+	mpnorm(r);
+
+	return 0;
+}
+
+/*
+ * return an Mfield for N when N<<s = 2^(k*Dbits) - c for a single
+ * digit c (s being the shift that sets the top bit of N's top
+ * digit), or nil when N does not have that form, has too few or
+ * too many digits, or memory runs out.
+ */
+Mfield*
+cnfield(mpint *N)
+{
+	mpint *M, *C;
+	CNfield *f;
+	mpdigit d;
+	int s;
+
+	if(N->top <= 2 || N->top >= MAXDIG)
+		return nil;
+
+	/* a zero top digit (unnormalized N) would make the shift loop spin forever */
+	d = N->p[N->top-1];
+	if(d == 0)
+		return nil;
+	for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
+		d <<= 1;
+
+	f = nil;
+	C = mpnew(0);
+	M = mpnew(0);
+	mpleft(N, s, M);
+	mpleft(mpone, M->top*Dbits, C);
+	mpsub(C, M, C);
+	if(C->top != 1)
+		goto out;
+	f = mallocz(sizeof(CNfield) + M->top*sizeof(mpdigit), 1);
+	if(f == nil)
+		goto out;
+	f->s = s;
+	f->c = C->p[0];
+
+	/* m's digits live in the same allocation, right after f */
+	f->m->size = M->top;
+	f->m->p = (mpdigit*)&f[1];
+	mpassign(M, f->m);
+	mpassign(N, f);
+	f->reduce = cnreduce;
+	f->flags |= MPfield;
+out:
+	mpfree(M);
+	mpfree(C);
+
+	return f;
+}
--- /dev/null
+++ b/sys/src/libmp/port/gmfield.c
@@ -1,0 +1,225 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * fast reduction for generalized mersenne numbers (GM)
+ * using a series of additions and subtractions.
+ *
+ * gmfield() splits N into small signed per-digit coefficients,
+ * works out how each high digit position of a number below N^2
+ * maps onto the low digit positions, and records that as rows of
+ * digit indices (indx) that gmreduce() gathers and adds or
+ * subtracts.
+ */
+
+enum {
+	MAXDIG = 1024/Dbits,
+};
+
+typedef struct GMfield GMfield;
+struct GMfield
+{
+	Mfield;
+
+	mpint	m2[1];	/* N^2; gmreduce only accepts inputs below it */
+
+	int	nadd;	/* number of addition rows at the start of indx */
+	int	nsub;	/* number of subtraction rows following them */
+	int	indx[256];	/* rows of d digit indices; 0 selects a zero digit */
+};
+
+/*
+ * set r = |a| mod N and return 0, or return -1 without touching r
+ * when |a| >= N^2.
+ */
+static int
+gmreduce(Mfield *m, mpint *a, mpint *r)
+{
+	GMfield *g = (GMfield*)m;
+	mpdigit d0, t[MAXDIG];
+	int i, j, d, *x;
+
+	if(mpmagcmp(a, g->m2) >= 0)
+		return -1;
+
+	if(a != r)
+		mpassign(a, r);
+
+	/* 2d+2 digits: lo(a), hi(a), and scratch for the final subtraction */
+	d = g->top;
+	mpbits(r, (d+1)*Dbits*2);
+
+	/* t[d..2d-1] = hi(a); t[0..d-1] is the gather buffer for each row */
+	memmove(t+d, r->p+d, d*Dbytes);
+
+	/* r = lo(a), plus an extra digit r->p[d] for carries */
+	r->sign = 1;
+	r->top = d;
+	r->p[d] = 0;
+
+	/* presumably so the subtraction rows below cannot take r negative */
+	if(g->nsub > 0)
+		mpvecdigmuladd(g->p, d, g->nsub, r->p);
+
+	/* gather one row of high digits into t[0..d-1] and add it */
+	x = g->indx;
+	for(i=0; i<g->nadd; i++){
+		t[0] = 0;
+		d0 = t[*x++];
+		for(j=1; j<d; j++)
+			t[j] = t[*x++];
+		t[0] = d0;
+
+		mpvecadd(r->p, d+1, t, d, r->p);
+	}
+
+	/* likewise for the subtraction rows */
+	for(i=0; i<g->nsub; i++){
+		t[0] = 0;
+		d0 = t[*x++];
+		for(j=1; j<d; j++)
+			t[j] = t[*x++];
+		t[0] = d0;
+
+		mpvecsub(r->p, d+1, t, d, r->p);
+	}
+
+	/*
+	 * subtract r->p[d] * N to bring r back to d digits.
+	 * NOTE(review): the digit left in r->p[d] is discarded; this
+	 * assumes it is zero for every modulus gmfield accepts.
+	 */
+	mpvecdigmulsub(g->p, d, r->p[d], r->p);
+	r->p[d] = 0;
+
+	/*
+	 * if(r >= N) r -= N: r - N goes to r->p[d+1..2d+1], whose top
+	 * digit is the borrow (0 or ~0, since r->p[d] == 0) and is used
+	 * as a mask to select without branching.
+	 */
+	mpvecsub(r->p, d+1, g->p, d, r->p+d+1);
+	d0 = r->p[2*d+1];
+	for(j=0; j<d; j++)
+		r->p[j] = (r->p[j] & d0) | (r->p[j+d+1] & ~d0);
+
+	mpnorm(r);
+
+	return 0;
+}
+
+/*
+ * return an Mfield for N when N has 3 to MAXDIG/2 digits, a
+ * multiple of Dbits significant bits, every digit (after the lower
+ * ones are cancelled) in [0,255] or, read as negative, in [-256,-1],
+ * and a reduction pattern that fits in indx; nil otherwise or when
+ * memory runs out.
+ */
+Mfield*
+gmfield(mpint *N)
+{
+	int i,j,d, s, *C, *X, *x, *e;
+	mpint *M, *T;
+	GMfield *g;
+
+	d = N->top;
+	if(d <= 2 || d > MAXDIG/2 || (mpsignif(N) % Dbits) != 0)
+		return nil;
+	g = nil;
+	T = mpnew(0);
+	M = mpcopy(N);
+	C = malloc(sizeof(int)*(d+1));
+	X = malloc(sizeof(int)*(d*d));
+	if(C == nil || X == nil)
+		goto out;
+
+	/* peel off digit i of N as the small signed value j; C[d-i] = -j */
+	for(i=0; i<=d; i++){
+		if((M->p[i]>>8) != 0 && (~M->p[i]>>8) != 0)
+			goto out;
+		j = M->p[i];
+		C[d - i] = -j;
+		itomp(j, T);
+		mpleft(T, i*Dbits, T);
+		mpsub(M, T, M);
+	}
+
+	/*
+	 * row i of X holds, per low digit position, the coefficient with
+	 * which high digit position d+i is folded there.  row 0 comes from
+	 * C; each further row shifts the previous one up a digit and folds
+	 * its top entry back in through C.
+	 */
+	for(j=0; j<d; j++)
+		X[j] = C[d-j];
+	for(i=1; i<d; i++){
+		X[d*i] = X[d*(i-1) + d-1]*C[d];
+		for(j=1; j<d; j++)
+			X[d*i + j] = X[d*(i-1) + j-1] + X[d*(i-1) + d-1]*C[d-j];
+	}
+	g = mallocz(sizeof(GMfield) + (d+1)*sizeof(mpdigit)*2, 1);
+	if(g == nil)
+		goto out;
+
+	/* N^2's digits live in the same allocation, right after g */
+	g->m2->p = (mpdigit*)&g[1];
+	g->m2->size = d*2+1;
+	mpmul(N, N, g->m2);
+	mpassign(N, g);
+	g->reduce = gmreduce;
+	g->flags |= MPfield;
+
+	/*
+	 * spend the positive coefficients on addition rows, then the
+	 * negative ones on subtraction rows.  in each row every high digit
+	 * position contributes at most one unit, every output digit j takes
+	 * at most one, and x[j] records the source index d+i.  a row that
+	 * takes nothing ends the pass.
+	 */
+	s = 0;
+	x = g->indx;
+	e = x + nelem(g->indx) - d;
+	for(g->nadd=0; x <= e; x += d, g->nadd++){
+		s = 0;
+		for(i=0; i<d; i++){
+			for(j=0; j<d; j++){
+				if(X[d*i+j] > 0 && x[j] == 0){
+					X[d*i+j]--;
+					x[j] = d+i;
+					s = 1;
+					break;
+				}
+			}
+		}
+		if(s == 0)
+			break;
+	}
+	for(g->nsub=0; x <= e; x += d, g->nsub++){
+		s = 0;
+		for(i=0; i<d; i++){
+			for(j=0; j<d; j++){
+				if(X[d*i+j] < 0 && x[j] == 0){
+					X[d*i+j]++;
+					x[j] = d+i;
+					s = 1;
+					break;
+				}
+			}
+		}
+		if(s == 0)
+			break;
+	}
+
+	/* indx filled up while rows were still taking coefficients */
+	if(s != 0){
+		mpfree(g);
+		g = nil;
+	}
+out:
+	free(C);
+	free(X);
+	mpfree(M);
+	mpfree(T);
+	return g;
+}
+
--- a/sys/src/libmp/port/mkfile
+++ b/sys/src/libmp/port/mkfile
@@ -38,6 +38,9 @@
mptoui\
mptov\
mptouv\
+ mpfield\
+ cnfield\
+ gmfield\
mplogic\
ALLOFILES=${FILES:%=%.$O}
--- a/sys/src/libmp/port/mpaux.c
+++ b/sys/src/libmp/port/mpaux.c
@@ -137,7 +137,7 @@
setmalloctag(new, getcallerpc(&old));
new->sign = old->sign;
new->top = old->top;
- new->flags = old->flags & ~MPstatic;
+ new->flags = old->flags & ~(MPstatic|MPfield);
memmove(new->p, old->p, Dbytes*old->top);
return new;
}
@@ -152,7 +152,7 @@
new->sign = old->sign;
new->top = old->top;
new->flags &= ~MPnorm;
- new->flags |= old->flags & ~MPstatic;
+ new->flags |= old->flags & ~(MPstatic|MPfield);
memmove(new->p, old->p, Dbytes*old->top);
}
--- a/sys/src/libmp/port/mpexp.c
+++ b/sys/src/libmp/port/mpexp.c
@@ -61,24 +61,22 @@
j = 0;
for(;;){
for(; bit != 0; bit >>= 1){
- mpmul(t[j], t[j], t[j^1]);
- if(bit & d)
- mpmul(t[j^1], b, t[j]);
+ if(m != nil)
+ mpmodmul(t[j], t[j], m, t[j^1]);
else
+ mpmul(t[j], t[j], t[j^1]);
+ if(bit & d) {
+ if(m != nil)
+ mpmodmul(t[j^1], b, m, t[j]);
+ else
+ mpmul(t[j^1], b, t[j]);
+ } else
j ^= 1;
- if(m != nil && t[j]->top > m->top){
- mpmod(t[j], m, t[j^1]);
- j ^= 1;
- }
}
if(--i < 0)
break;
bit = mpdighi;
d = e->p[i];
- }
- if(m != nil){
- mpmod(t[j], m, t[j^1]);
- j ^= 1;
}
if(t[j] == res){
mpfree(t[j^1]);
--- /dev/null
+++ b/sys/src/libmp/port/mpfield.c
@@ -1,0 +1,31 @@
+#include "os.h"
+#include <mp.h>
+#include "dat.h"
+
+/*
+ * mpfield returns a stand-in for the modulus N that mpmod() can
+ * reduce by a faster route: a Crandall field (cnfield) is tried
+ * first, then a generalized mersenne field (gmfield).  when one of
+ * them applies, N is freed and the new Mfield takes its place.
+ * otherwise, or when N is nil, static or already a field, N is
+ * returned unchanged.
+ */
+mpint*
+mpfield(mpint *N)
+{
+	Mfield *f;
+
+	if(N == nil || (N->flags & (MPfield|MPstatic)) != 0)
+		return N;
+
+	f = cnfield(N);
+	if(f == nil)
+		f = gmfield(N);
+	if(f == nil)
+		return N;
+
+	/* charge the new allocation to our caller */
+	setmalloctag(f, getcallerpc(&N));
+	mpfree(N);
+	return f;
+}
--- a/sys/src/libmp/port/mpmod.c
+++ b/sys/src/libmp/port/mpmod.c
@@ -5,101 +5,29 @@
 void
 mpmod(mpint *x, mpint *n, mpint *r)
 {
- static int busy;
- static mpint *p, *m, *c, *v;
- mpdigit q[32], t[64], d;
- int sign, k, s, qn, tn;
+	int sign;
+	mpint *ns;
 
 sign = x->sign;
-
- assert(n->flags & MPnorm);
- if(n->top <= 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q))
- goto hard;
-
- /*
- * check if n = 2**k - c where c has few power of two factors
- * above the lowest digit.
- */
- for(k = n->top-1; k > 0; k--){
- d = n->p[k] >> 1;
- if((d+1 & d) != 0)
- goto hard;
- }
-
- d = n->p[n->top-1];
- for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
- d <<= 1;
-
- /* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
- k = n->top;
-
- while(_tas(&busy))
- ;
-
- if(p == nil || mpmagcmp(n, p) != 0){
- if(m == nil){
- m = mpnew(0);
- c = mpnew(0);
- p = mpnew(0);
- }
- mpleft(n, s, m);
- mpleft(mpone, k*Dbits, c);
- mpsub(c, m, c);
- if(c->top >= k){
- mpassign(mpzero, p);
- busy = 0;
- goto hard;
- }
- mpassign(n, p);
- }
-
- mpleft(x, s, r);
- if(r->top <= k){
- mpbits(r, (k+1)*Dbits);
- r->top = k+1;
- }
-
- /* q = hi(r) */
- qn = r->top - k;
- memmove(q, r->p+k, qn*Dbytes);
-
- /* r = lo(r) */
- r->top = k;
-
- do {
- /* t = q*c */
- tn = qn + c->top;
- memset(t, 0, tn*Dbytes);
- mpvecmul(q, qn, c->p, c->top, t);
-
- /* q = hi(t) */
- qn = tn - k;
- if(qn <= 0) qn = 0;
- else memmove(q, t+k, qn*Dbytes);
-
- /* r += lo(t) */
- if(tn > k)
- tn = k;
- mpvecadd(r->p, k, t, tn, r->p);
-
- /* if(r >= m) r -= m */
- mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
- for(tn = 0; tn < k; tn++)
- r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
- } while(qn > 0);
-
- busy = 0;
-
- if(s != 0)
- mpright(r, s, r);
- else
- mpnorm(r);
- goto done;
-
-hard:
- mpdiv(x, n, nil, r);
-
-done:
- if(sign < 0)
- mpmagsub(n, r, r);
+
+	/*
+	 * the reduction overwrites r.  when r is n itself, keep a copy
+	 * of the modulus for the sign fix-up below, and use mpdiv: a
+	 * field's reduce routine may read the modulus after writing r.
+	 */
+	ns = n;
+	if(sign < 0 && r == n)
+		ns = mpcopy(n);
+	if(r == n || (n->flags & MPfield) == 0
+	|| ((Mfield*)n)->reduce((Mfield*)n, x, r) != 0)
+		mpdiv(x, n, nil, r);
+
+	/*
+	 * for negative x, |r| is now |x| mod n and the residue in [0, n)
+	 * is n - |r|, except that a zero remainder must stay 0, not n.
+	 */
+	if(sign < 0 && r->top != 0)
+		mpmagsub(ns, r, r);
+	if(ns != n)
+		mpfree(ns);
 }