shithub: riscv

ref: 8a53b8192db4848626722a00bf4478a2836102d8
dir: /sys/src/cmd/upas/bayes/bayes.c/

View raw version
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <regexp.h>
#include "hash.h"

enum
{
	MAXTAB = 256,
	MAXBEST = 32,
};

/* one box (category) hash loaded from a command-line argument */
typedef struct Table Table;
struct Table
{
	char *file;	/* argv string naming the hash file; printed as the box name */
	Hash *hash;	/* word counts read from file */
	int nmsg;	/* count stored under "*nmsg*"; forced to 1 when missing or zero */
};

/* per-word statistics for one word of the message being classified */
typedef struct Word Word;
struct Word
{
	Stringtab *s;	/* from hmsg; points into the message hash, not a copy */
	int count[MAXTAB];	/* counts from each table */
	double p[MAXTAB];	/* probabilities from each table */
	double mp;	/* max probability */
	int mi;		/* w.p[w.mi] = w.mp */
};

/* box tables, in command-line order; the first ntab entries are valid */
Table tab[MAXTAB];
int ntab;

/*
 * the most decisive words of the current message, kept by noteword
 * in order of decreasing mp.  mbest is the limit on how many to keep
 * (set in main, adjustable with -m); nbest is how many are held now.
 */
Word best[MAXBEST];
int mbest;
int nbest;

/* set by -D: print per-word detail after each result line */
int debug;

/*
 * Print the command synopsis on standard error and exit with
 * status "usage".  The synopsis lists every flag main accepts:
 * -D (debug), -k (keywords) and -m maxword.
 */
void
usage(void)
{
	fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
	exits("usage");
}

/*
 * Allocate n bytes through mallocz(n, 1); a nil result is
 * reported with sysfatal("out of memory").
 */
void*
emalloc(int n)
{
	void *mem;

	mem = mallocz(n, 1);
	if(mem != nil)
		return mem;

	sysfatal("out of memory");
	return mem;
}

/*
 * Offer w as a candidate for best[], which is kept in order of
 * decreasing mp and holds at most mbest entries.  w is placed
 * immediately after the last entry whose mp is strictly greater
 * than w->mp; if that slot lies at or beyond mbest, w is dropped.
 * When the list is already full the final entry is discarded to
 * make room.
 */
void
noteword(Word *w)
{
	int pos;

	/* back up from the tail over every entry that does not beat w */
	pos = nbest;
	while(pos > 0 && !(w->mp < best[pos-1].mp))
		pos--;

	if(pos >= mbest)
		return;

	/* full: forget the weakest entry */
	if(nbest == mbest)
		nbest--;

	/* open a gap at pos */
	if(pos < nbest)
		memmove(best+pos+1, best+pos, (nbest-pos)*sizeof best[0]);

	best[pos] = *w;
	nbest++;
}

/*
 * Load the hash stored in file s into a newly allocated Hash
 * and return it.  Failure to open the file is fatal.
 * NOTE(review): the meaning of the final Breadhash argument (1)
 * is defined outside this file -- confirm against hash.h.
 */
Hash*
hread(char *s)
{
	Biobuf *bio;
	Hash *nh;

	bio = Bopenlock(s, OREAD);
	if(bio == nil)
		sysfatal("open %s: %r", s);

	nh = emalloc(sizeof *nh);
	Breadhash(bio, nh, 1);
	Bterm(bio);
	return nh;
}

void
main(int argc, char **argv)
{
	int i, j, a, mi, oi, tot, keywords;
	double totp, p, xp[MAXTAB];
	Hash *hmsg;
	Word w;
	Stringtab *s, *t;
	Biobuf bout;

	mbest = 15;
	keywords = 0;
	ARGBEGIN{
	case 'D':
		debug = 1;
		break;
	case 'k':
		keywords = 1;
		break;
	case 'm':
		mbest = atoi(EARGF(usage()));
		if(mbest > MAXBEST)
			sysfatal("cannot keep more than %d words", MAXBEST);
		break;
	default:
		usage();
	}ARGEND

	for(i=0; i<argc; i++)
		if(strcmp(argv[i], "~") == 0)
			break;

	if(i > MAXTAB)
		sysfatal("cannot handle more than %d tables", MAXTAB);

	if(i+1 >= argc)
		usage();

	for(i=0; i<argc; i++){
		if(strcmp(argv[i], "~") == 0)
			break;
		tab[ntab].file = argv[i];
		tab[ntab].hash = hread(argv[i]);
		s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
		if(s == nil || s->count == 0)
			tab[ntab].nmsg = 1;
		else
			tab[ntab].nmsg = s->count;
		ntab++;
	}

	Binit(&bout, 1, OWRITE);

	oi = ++i;
	for(a=i; a<argc; a++){
		hmsg = hread(argv[a]);
		nbest = 0;
		for(s=hmsg->all; s; s=s->link){
			w.s = s;
			tot = 0;
			totp = 0.0;
			for(i=0; i<ntab; i++){
				t = findstab(tab[i].hash, s->str, s->n, 0);
				if(t == nil)
					w.count[i] = 0;
				else
					w.count[i] = t->count;
				tot += w.count[i];
				p = w.count[i]/(double)tab[i].nmsg;
				if(p >= 1.0)
					p = 1.0;
				w.p[i] = p;
				totp += p;
			}

			if(tot < 5){		/* word does not appear enough; give to box 0 */
				w.p[0] = 0.5;
				for(i=1; i<ntab; i++)
					w.p[i] = 0.1;
				w.mp = 0.5;
				w.mi = 0;
				noteword(&w);
				continue;
			}

			w.mp = 0.0;
			for(i=0; i<ntab; i++){
				p = w.p[i];
				p /= totp;
				if(p < 0.01)
					p = 0.01;
				else if(p > 0.99)
					p = 0.99;
				if(p > w.mp){
					w.mp = p;
					w.mi = i;
				}
				w.p[i] = p;
			}
			noteword(&w);
		}

		totp = 0.0;
		for(i=0; i<ntab; i++){
			p = 1.0;
			for(j=0; j<nbest; j++)
				p *= best[j].p[i];
			xp[i] = p;
			totp += p;
		}
		for(i=0; i<ntab; i++)
			xp[i] /= totp;
		mi = 0;
		for(i=1; i<ntab; i++)
			if(xp[i] > xp[mi])
				mi = i;
		if(oi != argc-1)
			Bprint(&bout, "%s: ", argv[a]);
		Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
		if(keywords){
			for(i=0; i<nbest; i++){
				Bprint(&bout, " ");
				Bwrite(&bout, best[i].s->str, best[i].s->n);
				Bprint(&bout, " %f", best[i].p[mi]);
			}
		}
		freehash(hmsg);
		Bprint(&bout, "\n");
		if(debug){
			for(i=0; i<nbest; i++){
				Bwrite(&bout, best[i].s->str, best[i].s->n);
				Bprint(&bout, " %f", best[i].p[mi]);
				if(best[i].p[mi] < best[i].mp)
					Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
				Bprint(&bout, "\n");
			}
		}
	}
	Bterm(&bout);
}