ref: 2b23d05d57743af57385cd42c0fd2d223b11d8c8
parent: 1da09c81867fc8416f99b8d7a24d73b2e3acf6e6
author: Peter Mikkelsen <[email protected]>
date: Sun Jul 28 07:35:34 EDT 2024
Start working on constraints. Not even close to being useful yet
--- a/array.c
+++ b/array.c
@@ -20,6 +20,7 @@
vlong *intdata;
Rune *chardata;
Array **arraydata;
+ ConstraintVar **vardata;
};
};
@@ -47,6 +48,9 @@
case TypeArray:
size *= sizeof(Array *);
break;
+ case TypeVar:
+ size *= sizeof(ConstraintVar *);
+ break;
}
a->shape = allocextra(a, (sizeof(usize) * rank) + size);
@@ -74,6 +78,12 @@
}
+/* setvar: store constraint variable v at flat element index offset of a.
+ * a is expected to be a TypeVar array (its storage is vardata); no type or
+ * bounds check is done here. */
void
+setvar(Array *a, usize offset, ConstraintVar *v)
+{
+	ConstraintVar **vars = a->vardata;
+
+	vars[offset] = v;
+}
+
+void
setshape(Array *a, int dim, usize size)
{
a->shape[dim] = size;
@@ -115,7 +125,15 @@
return a->arraydata[i];
}
+/* getvar: return the constraint variable at flat element index i of a.
+ * a is expected to be a TypeVar array; no type or bounds check is done. */
+ConstraintVar *
+getvar(Array *a, usize i)
+{
+	ConstraintVar **vars = a->vardata;
+
+	return vars[i];
+}
+
+static int printconstraintvar(char *, ConstraintVar *, int);
static int printarraysub(char *, Array *, int);
+static int printexpr(char *, Ast *, int);
static int
printitem(char *p, Array *a, uvlong i, int depth)
{
@@ -126,6 +144,8 @@
return sprint(p, "%C", a->chardata[i]);
case TypeArray:
return printarraysub(p, a->arraydata[i], depth);
+ case TypeVar:
+ return printconstraintvar(p, a->vardata[i], depth);
default:
return sprint(p, "???");
}
@@ -141,6 +161,33 @@
}
static int
+printconstraintvar(char *buf, ConstraintVar *v, int depth)
+{
+	/*
+	 * Print v into buf and return the number of bytes written.
+	 * A variable that stands for a delayed expression prints as that
+	 * expression. Otherwise it prints as name⍙id, followed by its
+	 * constraints in a braced block indented to depth+1. The block is only
+	 * printed while extrainfo is set. The flag is cleared while each
+	 * constraint's expression is printed, so variables reached from inside
+	 * it (presumably via printexpr -> printitem) print without their own
+	 * constraint lists.
+	 * NOTE(review): extrainfo is static; if printexpr raises an error while
+	 * it is cleared, it is never restored — confirm whether that matters.
+	 */
+	static int extrainfo = 1;
+	char *p;
+	int saved;
+
+	if(v->ast)
+		return printexpr(buf, v->ast, 0);
+
+	p = buf + sprint(buf, "%s⍙%d", v->name, v->id);
+	if(v->count == 0 || !extrainfo)
+		return p-buf;
+
+	p += sprint(p, " {\n");
+	for(uvlong i = 0; i < v->count; i++){
+		p += indent(p, depth+1);
+		saved = extrainfo;
+		extrainfo = 0;
+		p += printexpr(p, v->constraints[i]->ast, depth+1);
+		extrainfo = saved;
+		p += sprint(p, "\n");
+	}
+	p += indent(p, depth);
+	p += sprint(p, "}");
+	return p-buf;
+}
+
+static int
printarraysub(char *buf, Array *a, int depth)
{
char *p = buf;
@@ -160,7 +207,7 @@
p += printitem(p, a, i, depth); /* TODO: quoting */
p += sprint(p, "'");
goto end;
- }else if(a->rank == 1 && a->type == TypeArray){
+ }else if(a->rank == 1 && (a->type == TypeArray || a->type == TypeVar)){
if(a->size == 0){
p += sprint(p, "( ⋄ )");
goto end;
@@ -175,6 +222,9 @@
}
p += sprint(p, ")");
goto end;
+ }else if(a->rank == 0 && a->type == TypeVar){
+ p += printitem(p, a, 0, depth);
+ goto end;
}
p += sprint(p, "Some array I can't print yet");
@@ -311,6 +361,7 @@
goto end;
type = a->arraydata[0]->type;
+
b = allocarray(type, a->rank, a->size);
for(uvlong dim = 0; dim < a->rank; dim++)
b->shape[dim] = a->shape[dim];
@@ -327,6 +378,12 @@
case TypeChar:
b->chardata[i] = a->arraydata[i]->chardata[0];
break;
+ case TypeVar:
+ b->vardata[i] = a->arraydata[i]->vardata[0];
+ break;
+ default:
+ b = a;
+ goto end;
}
}
end:
--- /dev/null
+++ b/constraint.c
@@ -1,0 +1,218 @@
+#include <u.h>
+#include <libc.h>
+#include <thread.h>
+
+#include "dat.h"
+#include "fns.h"
+
+/* monadic constraints */
+
+/* dyadic constraints */
+static void constraint_equal(Ast *, Array *, Array *);
+
+
+/*
+ * allocvar: create a fresh constraint variable and wrap it in a rank-0
+ * TypeVar array.
+ * name defaults to "⎕var" when nil and is stored by reference, not copied.
+ * Every variable gets the next value of a counter as its id, so that
+ * variables sharing a name can be told apart when printed as name⍙id.
+ */
+Array *
+allocvar(char *name)
+{
+	static int nextid = 0;
+	ConstraintVar *v;
+	Array *a;
+
+	v = alloc(DataConstraintVar);
+	v->name = name != nil ? name : "⎕var";
+	v->id = nextid++;
+
+	a = allocarray(TypeVar, 0, 1);
+	setvar(a, 0, v);
+	return a;
+}
+
+/*
+ * varast: turn an operand into an AST node for a delayed expression.
+ * A scalar constraint variable that already stands for an expression
+ * contributes that expression tree directly. Anything else (ordinary
+ * values as well as plain, unbound variables) is wrapped in a new
+ * AstConst node. A nil operand maps to nil.
+ */
+static Ast *
+varast(Array *a)
+{
+	Ast *node;
+
+	if(a == nil)
+		return nil;
+	if(gettype(a) == TypeVar && getrank(a) == 0 && getvar(a, 0)->ast != nil)
+		return getvar(a, 0)->ast;
+
+	node = alloc(DataAst);
+	node->tag = AstConst;
+	node->val = a;
+	return node;
+}
+
+/*
+ * delayedexpr: instead of applying primitive prim now, return a new
+ * constraint variable whose ast is the application x prim y. When x is nil
+ * it is the monadic application prim y.
+ * Operands go through varast, so operands that are themselves delayed
+ * expressions are spliced into the tree.
+ */
+Array *
+delayedexpr(int prim, Array *x, Array *y)
+{
+	Array *res;
+	Ast *fn, *app;
+
+	res = allocvar(nil);
+
+	fn = alloc(DataAst);
+	fn->tag = AstPrim;
+	fn->prim = prim;
+
+	/* Hook the node onto the variable right away, before filling it in. */
+	app = alloc(DataAst);
+	getvar(res, 0)->ast = app;
+	app->func = fn;
+	app->tag = x != nil ? AstDyadic : AstMonadic;
+	app->left = varast(x);
+	app->right = varast(y);
+
+	return res;
+}
+
+/*
+ * graphadd: add constraint c to g, together with (transitively) every
+ * variable it mentions and every other constraint on those variables.
+ * Constraints and variables already in g (pointer identity) are skipped.
+ * That also ends the recursion when constraints refer back to each other.
+ * Raises EInternal when either fixed-size table in g is full.
+ */
+void
+graphadd(ConstraintGraph *g, Constraint *c)
+{
+	for(uvlong i = 0; i < g->ccount; i++){
+		if(g->cs[i] == c)
+			return; /* The constraint is already there. TODO: make a better test */
+	}
+
+	if(g->ccount == nelem(g->cs))
+		error(EInternal, "not enough space in the constraint graph");
+	g->cs[g->ccount++] = c;
+
+	for(uvlong i = 0; i < nelem(c->vars); i++){
+		ConstraintVar *v = c->vars[i];
+		int seen = 0;
+
+		if(v == nil)
+			continue;
+		for(uvlong j = 0; j < g->vcount && !seen; j++)
+			seen = g->vs[j] == v;
+		if(seen)
+			continue;
+
+		/* vs is as bounded as cs; without this check a new variable past
+		 * the last slot would be written beyond the end of vs. */
+		if(g->vcount == nelem(g->vs))
+			error(EInternal, "not enough space in the constraint graph");
+		g->vs[g->vcount++] = v;
+
+		for(uvlong j = 0; j < v->count; j++)
+			graphadd(g, v->constraints[j]);
+	}
+}
+
+/*
+ * solve: find one value satisfying the constraints reachable from v.
+ * v must be a plain variable, not an expression; otherwise EDomain is raised.
+ * An unconstrained variable yields the scalar 0. Solving actual constraints
+ * is not implemented yet and raises EInternal.
+ */
+Array *
+solve(ConstraintVar *v)
+{
+	ConstraintGraph *g;
+	Array *res;
+
+	if(v->ast != nil)
+		error(EDomain, "Cannot solve expression. Use ⎕assert first.");
+
+	/* Plan: consider the constraints on v and search for a single solution,
+	 * failing with an appropriate error when none exists. Several search
+	 * strategies are possible; ⎕solve could perhaps let the user pick one
+	 * through its left argument.
+	 *
+	 * First collect every variable and constraint involved into one graph.
+	 * The graph's capacity is fixed for now.
+	 */
+	g = alloc(DataConstraintGraph);
+	for(uvlong i = 0; i < v->count; i++)
+		graphadd(g, v->constraints[i]);
+
+	if(g->ccount != 0)
+		error(EInternal, "⎕solve not implemented (%ulld vars and %ulld constraints)", g->vcount, g->ccount);
+
+	/* No constraints at all: any value will do, so pick 0. */
+	res = allocarray(TypeNumber, 0, 1);
+	setint(res, 0, 0);
+	return res;
+}
+
+/*
+ * constrain: worker for ⎕assert.
+ * v must stand for an expression; otherwise EDomain is raised. The only
+ * form understood so far is a dyadic ≡ (PMatch) whose operands are both
+ * AstConst nodes: values, or plain variables wrapped by varast. It becomes
+ * a CEqual constraint. Anything else raises EInternal.
+ */
+void
+constrain(ConstraintVar *v)
+{
+	Ast *e;
+	Array *left = nil;
+	Array *right = nil;
+	int dyadic;
+
+	if(!v->ast)
+		error(EDomain, "Expected a constraint expression, not a variable.");
+	e = v->ast;
+
+	/* Analyse the AST and add the appropriate constraints to the variables
+	 * involved. Eventually this should also simplify against the constraints
+	 * already present, and fail when that shows no solution can exist.
+	 */
+	if(e->tag != AstMonadic && e->tag != AstDyadic)
+		goto fail;
+	if(e->func->tag != AstPrim)
+		goto fail;
+
+	dyadic = e->tag == AstDyadic;
+	if(dyadic){
+		if(e->left->tag != AstConst)
+			goto fail;
+		left = e->left->val;
+	}
+	if(e->right->tag != AstConst)
+		goto fail;
+	right = e->right->val;
+
+	if(dyadic && e->func->prim == PMatch){
+		constraint_equal(e, left, right);
+		return;
+	}
+
+fail:
+	error(EInternal, "don't know how to assert the given constraint");
+}
+
+/*
+ * applyconstraint: attach constraint c to every scalar constraint variable
+ * among its operands.
+ * Each distinct variable is recorded once in c->vars and gets c appended to
+ * its own constraints list. Non-variable operands are ignored.
+ */
+static void
+applyconstraint(Constraint *c)
+{
+	Array *args[2];
+	int nvars = 0;
+
+	args[0] = c->left;
+	args[1] = c->right;
+
+	for(int i = 0; i < nelem(args); i++){
+		Array *a = args[i];
+		ConstraintVar *v;
+		int dup = 0;
+
+		if(gettype(a) != TypeVar || getrank(a) != 0)
+			continue;
+		v = getvar(a, 0);
+
+		/* The same variable on both sides (x ≡ x) must be registered only
+		 * once, or c would be listed (and printed) twice for v. */
+		for(int j = 0; j < nvars; j++)
+			if(c->vars[j] == v)
+				dup = 1;
+		if(dup)
+			continue;
+
+		c->vars[nvars++] = v;
+		v->count++;
+		v->constraints = allocextra(v, sizeof(*v->constraints) * v->count);
+		v->constraints[v->count-1] = c;
+	}
+
+	/* Should simplify here as well */
+}
+
+/* monadic constraints */
+
+/* dyadic constraints */
+
+/*
+ * constraint_equal: record the constraint x ≡ y as a CEqual constraint,
+ * keep the originating expression a (used when printing the variable's
+ * constraints), and attach it to the variables involved.
+ */
+static void
+constraint_equal(Ast *a, Array *x, Array *y)
+{
+	Constraint *eq = alloc(DataConstraint);
+
+	eq->left = x;
+	eq->right = y;
+	eq->tag = CEqual;
+	eq->ast = a;
+	applyconstraint(eq);
+}
\ No newline at end of file
--- a/dat.h
+++ b/dat.h
@@ -18,6 +18,9 @@
DataLocalList,
DataErrorCtx,
DataErrorTrap,
+ DataConstraint,
+ DataConstraintVar,
+ DataConstraintGraph,
DataMax,
};
@@ -151,6 +154,7 @@
TypeNumber,
TypeChar,
TypeArray,
+ TypeVar,
};
typedef struct Array Array;
@@ -230,6 +234,7 @@
ILocal,
IPop,
IDisplay,
+ IPushVar,
};
typedef struct ValueStack ValueStack;
@@ -317,4 +322,59 @@
uvlong count;
ErrorTrap **traps;
+};
+
+/* Kinds of constraint, stored in Constraint.tag. */
+enum ConstraintType
+{
+ CEqual, /* left ≡ right, created by constraint_equal for ⎕assert x≡y */
+};
+
+typedef struct Constraint Constraint;
+typedef struct ConstraintVar ConstraintVar;
+
+/* One asserted constraint between two operands. */
+struct Constraint
+{
+ int tag; /* an enum ConstraintType value */
+ Ast *ast; /* expression the constraint was asserted from; printed with the variable */
+
+ Array *left; /* operands; rank-0 TypeVar operands are the variables involved */
+ Array *right;
+
+ ConstraintVar *vars[2]; /* max 2 vars for now */
+};
+
+/* A constraint variable; held in rank-0 (or larger) arrays of type TypeVar. */
+struct ConstraintVar
+{
+ char *name; /* display name ("⎕var" if none given); not copied */
+ int id; /* counter value from allocvar; printed as name⍙id */
+
+ Ast *ast; /* non-nil: this variable stands for a delayed expression */
+
+ uvlong count; /* number of entries in constraints */
+ Constraint **constraints; /* constraints this variable takes part in */
+};
+
+/* Primitive ids. They are the designated indices of primspecs[] in prim.c
+ * and the value stored in AstPrim nodes (Ast.prim); do not reorder. */
+enum PrimitiveId
+{
+ PRight,
+ PLeft,
+ PPlus,
+ PMinus,
+ PRho,
+ PMatch,
+
+ /* constraint system functions: ⎕assert, ⎕all, ⎕solve, ⎕var */
+ PAssert,
+ PAll,
+ PSolve,
+ PVar,
+};
+
+/* Working set built by solve: every variable and constraint reachable from
+ * the variable being solved, filled in by graphadd. Capacity is fixed. */
+typedef struct ConstraintGraph ConstraintGraph;
+struct ConstraintGraph
+{
+ uvlong vcount; /* used entries of vs */
+ uvlong ccount; /* used entries of cs */
+
+ ConstraintVar *vs[128];
+ Constraint *cs[128];
};
\ No newline at end of file
--- a/eval.c
+++ b/eval.c
@@ -47,11 +47,13 @@
uvlong id = sym(s, a->name);
emitbyte(c, ILocal);
emituvlong(c, id);
- if(assign){
- emitbyte(c, IAssign);
+ if(!assign){ /* create a new constraint var */
+ emitbyte(c, IPushVar);
emituvlong(c, id);
- emitbyte(c, IPop);
}
+ emitbyte(c, IAssign);
+ emituvlong(c, id);
+ emitbyte(c, IPop);
}
static void
@@ -418,6 +420,10 @@
break;
case IDisplay:
/* nothing to do, IPop checks for it */
+ break;
+ case IPushVar:
+ o += getuvlong(c->instrs+o, &v);
+ pushval(values, allocvar(symname(m->symtab, v)));
break;
default:
error(EInternal, "unknown instruction in evalbc: %d", instr);
--- a/fns.h
+++ b/fns.h
@@ -4,6 +4,7 @@
void setint(Array *, usize, vlong);
void setchar(Array *, usize, Rune);
void setarray(Array *, usize, Array *);
+void setvar(Array *, usize, ConstraintVar *);
void setshape(Array *, int, usize);
int gettype(Array *);
int getrank(Array *);
@@ -11,10 +12,17 @@
vlong getint(Array *, usize);
Rune getchar(Array *, usize);
Array *getarray(Array *, usize);
+ConstraintVar *getvar(Array *, usize);
Array *simplifyarray(Array *);
char *printarray(Array *);
char *printfunc(Function *);
+
+/* constraint.c */
+Array *allocvar(char *);
+Array *delayedexpr(int, Array *, Array *);
+Array *solve(ConstraintVar *);
+void constrain(ConstraintVar *);
/* error.c */
#define trap(num) (setjmp(setuptrap(1, num)->env))
--- a/memory.c
+++ b/memory.c
@@ -41,6 +41,9 @@
[DataLocalList] = {.size = sizeof(LocalList) },
[DataErrorCtx] = {.size = sizeof(ErrorCtx) },
[DataErrorTrap] = {.size = sizeof(ErrorTrap) },
+ [DataConstraint] = {.size = sizeof(Constraint) },
+ [DataConstraintVar] = {.size = sizeof(ConstraintVar) },
+ [DataConstraintGraph] = {.size = sizeof(ConstraintGraph) },
};
void *
--- a/mkfile
+++ b/mkfile
@@ -4,6 +4,7 @@
SCRIPTS=lpa
OFILES=\
array.$O\
+ constraint.$O\
error.$O\
eval.$O\
fs.$O\
--- a/parse.c
+++ b/parse.c
@@ -38,6 +38,7 @@
else
ast = parseprog(tokens);
match(tokens, TokEnd);
+
return ast;
}
--- a/prim.c
+++ b/prim.c
@@ -7,15 +7,21 @@
/* NOTE: In LPA, system functions are treated as primitives as well */
+/* niladic functions */
+static Array *primfn_var(void);
+
/* monadic functions */
static Array *primfn_same(Array *);
static Array *primfn_shape(Array *);
+static Array *primfn_assert(Array *);
+static Array *primfn_allsolutions(Array *);
+static Array *primfn_solve(Array *);
+
/* dyadic functions */
static Array *primfn_left(Array *, Array *);
static Array *primfn_right(Array *, Array *);
static Array *primfn_match(Array *, Array *);
-
struct {
char *spelling;
int nameclass;
@@ -23,12 +29,38 @@
Array *(*monad)(Array *);
Array *(*dyad)(Array *, Array *);
} primspecs[] = {
- "⊢", NameclassFunc, nil, primfn_same, primfn_right,
- "⊣", NameclassFunc, nil, primfn_same, primfn_left,
- "+", NameclassFunc, nil, nil, nil,
- "-", NameclassFunc, nil, nil, nil,
- "⍴", NameclassFunc, nil, primfn_shape, nil,
- "≡", NameclassFunc, nil, nil, primfn_match,
+ [PRight] = {
+ "⊢", NameclassFunc, nil, primfn_same, primfn_right
+ },
+ [PLeft] = {
+ "⊣", NameclassFunc, nil, primfn_same, primfn_left,
+ },
+ [PPlus] = {
+ "+", NameclassFunc, nil, nil, nil
+ },
+ [PMinus] = {
+ "-", NameclassFunc, nil, nil, nil
+ },
+ [PRho] = {
+ "⍴", NameclassFunc, nil, primfn_shape, nil
+ },
+ [PMatch] = {
+ "≡", NameclassFunc, nil, nil, primfn_match
+ },
+
+ /* Constraint stuff. Pick glyphs for them later */
+ [PAssert] = {
+ "⎕assert", NameclassFunc, nil, primfn_assert, nil
+ },
+ [PAll] = {
+ "⎕all", NameclassFunc, nil, primfn_allsolutions, nil
+ },
+ [PSolve] = {
+ "⎕solve", NameclassFunc, nil, primfn_solve, nil
+ },
+ [PVar] = {
+ "⎕var", NameclassFunc, primfn_var, nil, nil
+ }
};
char *
@@ -47,6 +79,8 @@
primvalence(int id)
{
int valence = 0;
+ if(primspecs[id].nilad)
+ valence |= Niladic;
if(primspecs[id].monad)
valence |= Monadic;
if(primspecs[id].dyad)
@@ -68,30 +102,45 @@
+/* primnilad: apply primitive id with no arguments; EInternal if it has no niladic form. */
Array *
primnilad(int id)
{
- if(primspecs[id].nilad)
- return primspecs[id].nilad();
- else
+	Array *(*nilad)(void) = primspecs[id].nilad;
+
+	if(nilad == nil)
error(EInternal, "primitive %s has no niladic definition", primsymb(id));
+	return nilad();
}
+/* primmonad: apply primitive id to y; EInternal if it has no monadic form. */
Array *
primmonad(int id, Array *y)
{
- if(primspecs[id].monad)
- return primspecs[id].monad(y);
- else
+	Array *(*monad)(Array *) = primspecs[id].monad;
+	int delay;
+
+	if(monad == nil)
error(EInternal, "primitive %s has no monadic definition", primsymb(id));
+
+	/* A constraint-variable argument builds a delayed expression instead,
+	 * except for ⎕assert and ⎕solve, which take variables directly. */
+	delay = gettype(y) == TypeVar && id != PAssert && id != PSolve;
+	return delay ? delayedexpr(id, nil, y) : monad(y);
}
+/* primdyad: apply primitive id to x and y; EInternal if it has no dyadic form. */
Array *
primdyad(int id, Array *x, Array *y)
{
- if(primspecs[id].dyad)
- return primspecs[id].dyad(x, y);
- else
+	Array *(*dyad)(Array *, Array *) = primspecs[id].dyad;
+
+	if(dyad == nil)
error(EInternal, "primitive %s has no dyadic definition", primsymb(id));
+
+	/* Any constraint-variable operand turns the call into a delayed expression. */
+	if(gettype(x) != TypeVar && gettype(y) != TypeVar)
+		return dyad(x, y);
+	return delayedexpr(id, x, y);
}
+/* niladic functions */
+/* ⎕var: return a fresh constraint variable with the default name. */
+static Array *
+primfn_var(void)
+{
+	Array *fresh = allocvar(nil);
+
+	return fresh;
+}
+
/* monadic functions */
static Array *
primfn_same(Array *a)
@@ -112,6 +161,31 @@
return r;
}
+/*
+ * ⎕assert y: y must be a single (rank-0) constraint variable standing for a
+ * constraint expression. Its constraints are registered via constrain.
+ * Returns the scalar 0.
+ */
+static Array *
+primfn_assert(Array *y)
+{
+	Array *zero;
+
+	if(!(gettype(y) == TypeVar && getrank(y) == 0))
+		error(EDomain, "⎕assert expected a single constraint expression");
+	constrain(getvar(y, 0));
+
+	zero = allocarray(TypeNumber, 0, 1);
+	setint(zero, 0, 0);
+	return zero;
+}
+
+/*
+ * ⎕all: not meant to be evaluated as an ordinary function; calling it
+ * raises EInternal. A constraint-variable argument never reaches it,
+ * because primmonad builds a delayed expression first.
+ * NOTE(review): presumably ⎕all is meant to be handled specially elsewhere — confirm.
+ */
+static Array *
+primfn_allsolutions(Array *)
+{
+ error(EInternal, "⎕all should never be evaluated");
+}
+
+/*
+ * ⎕solve y: y must be a single (rank-0) constraint variable.
+ * Returns one solution for it as computed by solve; solve raises its own
+ * errors when y is an expression or cannot be solved.
+ */
+static Array *
+primfn_solve(Array *y)
+{
+	if(gettype(y) != TypeVar || getrank(y) != 0)
+		error(EDomain, "⎕solve expected a single constraint variable");
+	return solve(getvar(y, 0));
+}
+
/* dyadic functions */
static Array *
primfn_left(Array *x, Array *)
@@ -175,4 +249,4 @@
Array *z = allocarray(TypeNumber, 0, 1);
setint(z, 0, matches(x, y));
return z;
-}
\ No newline at end of file
+}
--- a/util.c
+++ b/util.c
@@ -188,6 +188,10 @@
case IDisplay:
print("DISPLAY\n");
break;
+ case IPushVar:
+ o += getuvlong(c->instrs+o, &v);
+ print("PUSHVAR %ulld\n", v);
+ break;
default:
print("???");
return;