shithub: riscv

Download patch

ref: 0bfac109a491e61d7cd585060b88e1251da1e928
parent: 340d83d49d659b53e711ab59d6e819be39a0ad16
author: cinap_lenrek <[email protected]>
date: Mon Feb 1 14:27:57 EST 2016

mpc: constant expression folding

--- a/sys/src/cmd/mpc.y
+++ b/sys/src/cmd/mpc.y
@@ -28,6 +28,7 @@
 	Node*	l;
 	Node*	r;
 	Sym*	s;
+	mpint*	m;
 	int	n;
 };
 
@@ -238,11 +239,11 @@
 	{
 		$$ = new('e', $1, $2);
 	}
-|	expr LSH num
+|	expr LSH expr
 	{
 		$$ = new(LSH, $1, $3);
 	}
-|	expr RSH num
+|	expr RSH expr
 	{
 		$$ = new(RSH, $1, $3);
 	}
@@ -390,6 +391,7 @@
 	n->l = l;
 	n->r = r;
 	n->s = nil;
+	n->m = nil;
 	n->n = lineno;
 	return n;
 }
@@ -561,7 +563,7 @@
 {
 	if(n->c == NAME)
 		return 0;
-	if(n->c == NUM && strlen(n->s->n) == 1 && atoi(n->s->n) < 3)
+	if(n->c == NUM && n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0)
 		return 0;
 	return 1;
 }
@@ -570,37 +572,131 @@
 bcom(Node *n, Node *t);
 
 Node*
+ccom(Node *f)
+{
+	Node *l, *r;
+	/* constant folding: collapse constant subtrees of f in place; returns f (nil for nil) */
+	if(f == nil)
+		return nil;
+	/* a non-nil f->m means f was already visited; ~0 below is a non-nil placeholder */
+	if(f->m != nil)
+		return f;
+	f->m = (void*)~0;
+
+	switch(f->c){
+	case NUM:
+		f->m = strtomp(f->s->n, nil, 0, nil);	/* parse the literal's text */
+		if(f->m == nil)
+			diag(f, "bad constant");
+		goto out;
+
+	case LSH:
+	case RSH:
+		break;
+	/* arithmetic is folded only with no modulus, or when the modulus node is itself a NUM */
+	case '+':
+	case '-':
+	case '*':
+	case '/':
+	case '%':
+	case '^':
+		if(modulo == nil || modulo->c == NUM)
+			break;
+
+		/* fall through: modulus is not a constant, leave f unfolded */
+	default:
+		return f;
+	}
+	/* fold the children first; give up unless both exist and are constants */
+	f->l = l = ccom(f->l);
+	f->r = r = ccom(f->r);
+	if(l == nil || r == nil || l->c != NUM || r->c != NUM)
+		return f;
+
+	f->m = mpnew(0);
+	switch(f->c){
+	case LSH:
+	case RSH:
+		if(mpsignif(r->m) > 32)
+			diag(f, "bad shift");	/* NOTE(review): assumes diag() does not return -- confirm */
+		if(f->c == LSH)
+			mpleft(l->m, mptoi(r->m), f->m);
+		else
+			mpright(l->m, mptoi(r->m), f->m);
+		goto out;	/* shift results skip the mpmod reduction below */
+
+	case '+':
+		mpadd(l->m, r->m, f->m);
+		break;
+	case '-':
+		mpsub(l->m, r->m, f->m);
+		break;
+	case '*':
+		mpmul(l->m, r->m, f->m);
+		break;
+	case '/':
+		if(modulo != nil){
+			mpinvert(r->m, modulo->m, f->m);	/* NOTE(review): relies on modulo->m being set already; not done in this function -- confirm */
+			mpmul(f->m, l->m, f->m);
+		} else {
+			mpdiv(l->m, r->m, f->m, nil);	/* NOTE(review): no check here for a zero constant divisor -- confirm */
+		}
+		break;
+	case '%':
+		mpmod(l->m, r->m, f->m);
+		break;
+	case '^':
+		mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
+		goto out;	/* modulus (if any) was already passed to mpexp */
+	}
+	if(modulo != nil)
+		mpmod(f->m, modulo->m, f->m);
+	/* rewrite f as a leaf NUM whose value is carried in f->m */
+out:
+	f->l = nil;
+	f->r = nil;
+	f->s = nil;
+	f->c = NUM;
+	return f;
+}
+
+Node*
 ecom(Node *f, Node *t)
 {
 	Node *l, *r, *t2;
-	mpint *m;
 
 	if(f == nil)
 		return nil;
 
+	f = ccom(f);
 	if(f->c == NUM){
-		m = strtomp(f->s->n, nil, 0, nil);
-		if(m == nil)
-			diag(f, "bad constant");
-		if(mpcmp(m, mpzero) == 0){
+		if(f->m->sign < 0){
+			f->m->sign = 1;
+			t = ecom(f, t);
+			f->m->sign = -1;
+			if(isconst(t))
+				t = ecom(t, alloctmp());
+			cprint("%N->sign = -1;\n", t);
+			return t;
+		}
+		if(mpcmp(f->m, mpzero) == 0){
 			f->c = NAME;
 			f->s = sym("mpzero");
 			f->s->f = FSET;
 			return ecom(f, t);
 		}
-		if(mpcmp(m, mpone) == 0){
+		if(mpcmp(f->m, mpone) == 0){
 			f->c = NAME;
 			f->s = sym("mpone");
 			f->s->f = FSET;
 			return ecom(f, t);
 		}
-		if(mpcmp(m, mptwo) == 0){
+		if(mpcmp(f->m, mptwo) == 0){
 			f->c = NAME;
 			f->s = sym("mptwo");
 			f->s->f = FSET;
 			return ecom(f, t);
 		}
-		mpfree(m);
 	}
 
 	if(f->c == ','){
@@ -645,24 +741,23 @@
 
 	switch(f->c){
 	case NUM:
-		m = strtomp(f->s->n, nil, 0, nil);
-		if(m == nil)
-			diag(f, "bad constant");
-		if(mpsignif(m) <= 32)
-			cprint("uitomp(%udUL, %N);\n", mptoui(m), t);
-		else if(mpsignif(m) <= 64)
-			cprint("uvtomp(%lludULL, %N);\n", mptouv(m), t);
+		if(mpsignif(f->m) <= 32)
+			cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
+		else if(mpsignif(f->m) <= 64)
+			cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
 		else
-			cprint("strtomp(\"%.16B\", nil, 16, %N);\n", m, t);
-		mpfree(m);
+			cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
 		goto out;
 	case LSH:
-		l = f->l->c == NAME ? f->l : ecom(f->l, t);
-		cprint("mpleft(%N, %N, %N);\n", l, f->r, t);
-		goto out;
 	case RSH:
+		r = ccom(f->r);
+		if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
+			diag(f, "bad shift");
 		l = f->l->c == NAME ? f->l : ecom(f->l, t);
-		cprint("mpright(%N, %N, %N);\n", l, f->r, t);
+		if(f->c == LSH)
+			cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
+		else
+			cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
 		goto out;
 	case '*':
 	case '/':
@@ -670,8 +765,10 @@
 		r = ecom(f->r, nil);
 		break;
 	default:
-		l = ecom(f->l, complex(f->l) && !symref(f->r, t->s) ? t : nil);
-		r = ecom(f->r, complex(f->r) && l->s != t->s ? t : nil);
+		l = ccom(f->l);
+		r = ccom(f->r);
+		l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
+		r = ecom(r, complex(r) && l->s != t->s ? t : nil);
 		break;
 	}
 
@@ -975,8 +1072,11 @@
 		return fmtprint(f, "%N, %N", n->l, n->r);
 
 	switch(n->c){
-	case NAME:
 	case NUM:
+		if(n->m != nil)
+			return fmtprint(f, "%B", n->m);
+		/* wet floor */
+	case NAME:
 		return fmtprint(f, "%s", n->s->n);
 	case EQ:
 		return fmtprint(f, "==");