/*
 * shithub: riscv
 * ref: f9efc8ed87337b0d1220822b9b0356966f8fce44
 * dir: /sys/src/cmd/forp/logic.c/
 */
#include <u.h>
#include <libc.h>
#include <mp.h>
#include <sat.h>
#include "dat.h"
#include "fns.h"

/*
 * Counter for SAT variable numbers, defined outside this file.
 * The functions below post-increment it whenever they need a
 * fresh variable to stand for the result of a gate.
 */
extern int satvar;

/*
 * satand1 returns a literal equivalent to the conjunction of the
 * literals a[0..n-1].  If n is negative the list is taken to be
 * terminated by a 0 entry.
 *
 * Literals 1 and -2 are treated as constant false, 2 and -1 as
 * constant true.  A false input makes the result 1 (false); true
 * inputs are ignored.  If at most one non-constant input remains,
 * that input (or 2 when none remains) is returned and no clauses are
 * added.  Otherwise a fresh variable r is taken from satvar and the
 * clauses (-r v a[i]) for every remaining input plus
 * (-a[0] v ... v -a[k] v r) are added to sat.
 */
int
satand1(SATSolve *sat, int *a, int n)
{
	int i, j, r;
	int *b;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	r = 2;
	/* j counts the constant-true inputs; r remembers the last other one */
	for(i = j = 0; i < n; i++){
		if(a[i] == 1 || a[i] == -2)
			return 1;
		if(a[i] == 2 || a[i] == -1)
			j++;
		else
			r = a[i];
	}
	if(j >= n - 1) return r;
	r = satvar++;
	/* room for every input plus the trailing r */
	b = emalloc(sizeof(int) * (n+1));
	for(i = j = 0; i < n; i++){
		if(a[i] == 2 || a[i] == -1)
			continue;
		b[j++] = -a[i];
		sataddv(sat, -r, a[i], 0);
	}
	b[j++] = r;
	satadd1(sat, b, j);
	/*
	 * The buffer is only scratch space: pimpsat in this file reuses
	 * and frees its buffer after satadd1 in the same way.
	 */
	free(b);
	return r;
}

/*
 * satandv is the variadic form of satand1: the literals follow sat
 * in the argument list and are handed to satand1 with n = -1, so the
 * list has to end with a 0 argument.
 * NOTE(review): satvafix is declared outside this file; presumably it
 * makes the argument area readable as an int array -- confirm.
 */
int
satandv(SATSolve *sat, ...)
{
	va_list args;
	int lit;

	va_start(args, sat);
	satvafix(args);
	lit = satand1(sat, (int*)args, -1);
	va_end(args);
	return lit;
}

/*
 * sator1 returns a literal equivalent to the disjunction of the
 * literals a[0..n-1].  If n is negative the list is taken to be
 * terminated by a 0 entry.
 *
 * Literals 2 and -1 are treated as constant true, 1 and -2 as
 * constant false.  A true input makes the result 2 (true); false
 * inputs are ignored.  If at most one non-constant input remains,
 * that input (or 1 when none remains) is returned and no clauses are
 * added.  Otherwise a fresh variable r is taken from satvar and the
 * clauses (r v -a[i]) for every remaining input plus
 * (a[0] v ... v a[k] v -r) are added to sat.
 */
int
sator1(SATSolve *sat, int *a, int n)
{
	int i, j, r;
	int *b;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	r = 1;
	/* j counts the constant-false inputs; r remembers the last other one */
	for(i = j = 0; i < n; i++){
		if(a[i] == 2 || a[i] == -1)
			return 2;
		if(a[i] == 1 || a[i] == -2)
			j++;
		else
			r = a[i];
	}
	if(j >= n-1) return r;
	r = satvar++;
	/* room for every input plus the trailing -r */
	b = emalloc(sizeof(int) * (n+1));
	for(i = j = 0; i < n; i++){
		if(a[i] == 1 || a[i] == -2)
			continue;
		b[j++] = a[i];
		sataddv(sat, r, -a[i], 0);
	}
	b[j++] = -r;
	satadd1(sat, b, j);
	/*
	 * The buffer is only scratch space: pimpsat in this file reuses
	 * and frees its buffer after satadd1 in the same way.
	 */
	free(b);
	return r;
}

/*
 * satorv is the variadic form of sator1: the literals follow sat in
 * the argument list and are handed to sator1 with n = -1, so the
 * list has to end with a 0 argument.
 * NOTE(review): satvafix is declared outside this file; presumably it
 * makes the argument area readable as an int array -- confirm.
 */
int
satorv(SATSolve *sat, ...)
{
	va_list args;
	int lit;

	va_start(args, sat);
	satvafix(args);
	lit = sator1(sat, (int*)args, -1);
	va_end(args);
	return lit;
}

/*
 * One implicant of a truth table.  x holds the values of the input
 * bits; m is a mask of the input positions that are "don't care"
 * (pimpmask expands every subset of m, pimpsat leaves those inputs
 * out of the clause).
 */
typedef struct { u8int x, m; } Pi;
/* implicant list built by pimp (grown with realloc) and its length */
static Pi *π;
static int nπ;
/* πm[i], filled in by pimpmask, has bit k set when π[i] covers truth-table row k */
static u64int *πm;

/*
 * pimp rebuilds the global list π/nπ for the truth table op over n
 * inputs (bit k of op is the output for input row k).  It starts
 * from the rows where op is 1, repeatedly merges entries that differ
 * in a single input bit, and finally drops every entry that is
 * contained in a later one.
 * NOTE(review): the realloc results are used unchecked.
 */
static void
pimp(u64int op, int n)
{
	int i, j, k;
	u8int δ;

	nπ = 0;
	/* one entry with an empty don't-care mask per row where op is 1 */
	for(i = 0; i < 1<<n; i++)
		if((op >> i & 1) != 0){
			π = realloc(π, sizeof(Pi) * (nπ + 1));
			π[nπ++] = (Pi){i, 0};
		}
	/*
	 * nπ grows inside this loop, so entries appended by a merge are
	 * themselves considered for further merging later on.
	 */
	for(i = 0; i < nπ; i++){
		for(j = 0; j < i; j++){
			δ = π[i].x ^ π[j].x;
			/* need equal masks and values differing in exactly one bit */
			if(δ == 0 || (δ & δ - 1) != 0 || π[i].m != π[j].m) continue;
			/* the differing bit must not already be a don't-care */
			if(((π[i].m | π[j].m) & δ) != 0) continue;
			/* skip if identical to the entry appended most recently */
			if(π[nπ-1].x == (π[i].x & π[j].x) && π[nπ-1].m == (π[i].m | δ)) continue;
			π = realloc(π, sizeof(Pi) * (nπ + 1));
			π[nπ++] = (Pi){π[i].x & π[j].x, π[i].m | δ};
		}
	}
	/*
	 * Compact in place: keep entry i only if no later entry j has a
	 * mask that includes π[i].m and agrees with π[i].x outside j's
	 * mask.  This also removes all but the last of any duplicates.
	 */
	for(i = k = 0; i < nπ; i++){
		for(j = i+1; j < nπ; j++)
			if((π[i].m & ~π[j].m) == 0 && (π[i].x & ~π[j].m) == π[j].x)
				break;
		if(j == nπ)
			π[k++] = π[i];
	}
	nπ = k;
	assert(nπ <= 1<<n);
}

/*
 * pimpmask recomputes πm for the current π list: πm[t] gets one bit
 * per truth-table row covered by π[t], i.e. π[t].x combined with
 * every subset of the don't-care mask π[t].m.
 */
static void
pimpmask(void)
{
	int t, sub;
	u64int rows;

	πm = realloc(πm, sizeof(u64int) * nπ);
	for(t = 0; t < nπ; t++){
		rows = 0;
		sub = π[t].m;
		/* walk all subsets of the mask, from the full mask down to 0 */
		for(;;){
			rows |= 1ULL<<(π[t].x | sub);
			if(sub == 0)
				break;
			sub = (sub - 1) & π[t].m;
		}
		πm[t] = rows;
	}
}

/*
 * popcnt returns the number of 1 bits in m.  Adjacent fields of
 * width 1, 2, 4, 8 and 16 bits are summed pairwise, then the two
 * 32-bit halves are added together.
 */
static int
popcnt(u64int m)
{
	static u64int fld[5] = {
		0x5555555555555555ULL,
		0x3333333333333333ULL,
		0x0F0F0F0F0F0F0F0FULL,
		0x00FF00FF00FF00FFULL,
		0x0000FFFF0000FFFFULL,
	};
	int step;

	for(step = 0; step < 5; step++)
		m = (m & fld[step]) + (m >> (1<<step) & fld[step]);
	m = (u32int)m + (u32int)(m >> 32);
	return m;
}

/*
 * pimpcover chooses implicants from π whose row masks πm together
 * cover every 1 row of op, and returns the choice as a bitmask (bit
 * i set means π[i] is selected).  pimp and pimpmask must have been
 * run for op beforehand.  The second parameter is unused.
 *
 * First every implicant that covers some row no other implicant
 * covers is taken.  The remaining rows are then covered greedily:
 * for the lowest uncovered row, pick the implicant containing it
 * that covers the most still-uncovered rows.
 *
 * All bit tests use 64-bit shifts: both the implicant index and the
 * row number can exceed 30, where an int shift would sign-extend or
 * be undefined.
 */
static u64int
pimpcover(u64int op, int)
{
	int i, j, maxi, p, maxp;
	u64int cov, yes, m;
	
	yes = 0;
	cov = op;
	for(i = 0; i < nπ; i++){
		/* m: rows covered by π[i] and by no other implicant */
		m = πm[i];
		for(j = 0; j < nπ; j++){
			if(j == i) continue;
			m &= ~πm[j];
			if(m == 0) break;
		}
		if(j == nπ){
			yes |= 1ULL<<i;
			cov &= ~πm[i];
		}
	}
	while(cov != 0){
		/* index of the lowest set bit of cov */
		j = popcnt(~cov & cov - 1);
		maxi = -1;
		maxp = 0;
		for(i = 0; i < nπ; i++){
			if((πm[i] & 1ULL<<j) == 0) continue;
			if((p = popcnt(πm[i] & cov)) > maxp)
				maxi = i, maxp = p;
		}
		assert(maxi >= 0);
		yes |= 1ULL<<maxi;
		cov &= ~πm[maxi];
	}
	return yes;
}

/*
 * pimpsat adds one clause to sat for every implicant selected in
 * yes (bit i selects π[i]).  Each clause consists of r followed by,
 * for every input a[v] that is not a don't-care of the implicant,
 * -a[v] when the implicant requires the input to be 1 and a[v] when
 * it requires 0; in other words "implicant implies r".
 */
static void
pimpsat(SATSolve *sat, u64int yes, int *a, int n, int r)
{
	int idx, v, len;
	int *clause;
	Pi *p;

	clause = emalloc(sizeof(int) * (n + 1));
	/* visit the set bits of yes from lowest to highest */
	for(; yes != 0; yes &= yes - 1){
		idx = popcnt(~yes & yes - 1);
		p = &π[idx];
		clause[0] = r;
		len = 1;
		for(v = 0; v < n; v++){
			if((p->m & 1<<v) != 0)
				continue;
			if((p->x >> v & 1) != 0)
				clause[len++] = -a[v];
			else
				clause[len++] = a[v];
		}
		satadd1(sat, clause, len);
	}
	free(clause);
}

/*
 * satlogic1 returns a literal equivalent to an arbitrary boolean
 * function of the literals a[0..n-1].  op is the truth table: bit k
 * of op is the output for the input row whose bit i is the value of
 * a[i].  If n is negative the list is taken to be terminated by a 0
 * entry.  At most 6 inputs fit in the 64-bit table.
 *
 * Inputs 2 and -1 count as constant true; other values in -2..2
 * count as constant false.  If every input is constant the result
 * is 1 (false) or 2 (true).  If exactly one input is not constant
 * the result is 1, 2, that input or its negation.  Otherwise a
 * fresh variable r is taken from satvar and clauses are added to
 * sat from a cover of op (each implying r) and a cover of its
 * complement (each implying -r).
 */
int
satlogic1(SATSolve *sat, u64int op, int *a, int n)
{
	int i, j, o, r;
	int s;

	if(n < 0)
		for(n = 0; a[n] != 0; n++)
			;
	assert(n <= 6);
	/* for n == 6 the table uses all 64 bits; shifting by 64 is undefined */
	assert(n == 6 || op >> (1<<n) == 0);
	/*
	 * s collects the row selected by the constant inputs (a[0] in
	 * bit 0); j is the index of the single non-constant input, if
	 * any.  The loop stops early at a second non-constant input.
	 */
	s = 0;
	j = -1;
	for(i = n; --i >= 0; ){
		if((uint)(a[i] + 2) > 4){
			if(j >= 0) break;
			j = i;
		}
		s = s << 1 | a[i] == 2 | a[i] == -1;
	}
	if(i < 0){
		if(j < 0) return 1 + (op >> s & 1);
		/* bit 0: output with a[j] false, bit 1: output with a[j] true */
		o = op >> s & 1 | op >> s + (1<<j) - 1 & 2;
		switch(o){
		case 0: return 1;
		case 1: return -a[j];
		case 2: return a[j];
		case 3: return 2;
		}
	}
	r = satvar++;
	pimp(op, n);
	pimpmask();
	pimpsat(sat, pimpcover(op, n), a, n, r);
	/* complement the table within its 1<<n valid bits */
	op ^= (u64int)-1 >> 64-(1<<n);
	pimp(op, n);
	pimpmask();
	pimpsat(sat, pimpcover(op, n), a, n, -r);
	return r;
}

/*
 * satlogicv is the variadic form of satlogic1: the input literals
 * follow op in the argument list and are handed to satlogic1 with
 * n = -1, so the list has to end with a 0 argument.
 * NOTE(review): satvafix is declared outside this file; presumably it
 * makes the argument area readable as an int array -- confirm.
 */
int
satlogicv(SATSolve *sat, u64int op, ...)
{
	va_list args;
	int lit;

	va_start(args, op);
	satvafix(args);
	lit = satlogic1(sat, op, (int*)args, -1);
	va_end(args);
	return lit;
}