shithub: riscv

ref: 6fa3e08412c49a188e2abe7d9ee54cd1f37a85f5
dir: /sys/src/cmd/9nfs/server.c/

View raw version
#include "all.h"
#include <ndb.h>

static int	alarmflag;	/* set every 30s by the timer process in server(); cleared by servemsg() */

static int	Iconv(Fmt*);
static void	cachereply(Rpccall*, void*, int);
static int	replycache(int, Rpccall*, long (*)(int, void*, long));
static void	udpserver(int, Progmap*);
static void	tcpserver(int, Progmap*);
static void	getendpoints(Udphdr*, char*);
static long	readtcp(int, void*, long);
static long	writetcp(int, void*, long);
static int	servemsg(int, long (*)(int, void*, long), long (*)(int, void*, long),
		int, Progmap*);
void	(*rpcalarm)(void);	/* optional hook; servemsg() calls it when alarmflag was set */
int	rpcdebug;	/* -D: bumped once per occurrence; >1 enables rpcprint of calls and replies */
int	rejectall;	/* -r: when non-zero servemsg() denies every call with AUTH_TOOWEAK */
int	p9debug;	/* -9: bumped once per occurrence; not read in this file */

int	nocache;	/* -C: when non-zero cachereply() stores nothing */

uchar	buf[9000];	/* incoming message; the rpc data follows a Udphdr-sized header */
uchar	rbuf[9000];	/* reply as marshalled by rpcS2M() */
uchar	resultbuf[9000];	/* servemsg() points rreply.results here */

static int tcp;	/* -t: serve over tcp instead of udp */

char *commonopts = "[-9CDrtv]";			/* for usage() messages */

/*
 * this recognises common, nominally rpc-related options.
 * they may not take arguments.
 */
int
argopt(int c)
{
	/*
	 * every common option except -t is a counter, so repeating
	 * the flag raises the level; -t simply selects tcp.
	 * returns 0 if c was recognised, -1 otherwise.
	 */
	if(c == '9')
		++p9debug;
	else if(c == 'C')
		++nocache;
	else if(c == 'D')
		++rpcdebug;
	else if(c == 'r')
		++rejectall;
	else if(c == 't')
		tcp = 1;
	else if(c == 'v')
		++chatty;
	else
		return -1;
	return 0;
}

/*
 * all option parsing is now done in (*pg->init)(), which can call back
 * here to argopt for common options.
 *
 * server() installs the print formats used in this file, detaches from
 * the caller by forking, starts a timer process, runs every program's
 * init routine and then serves on myport over tcp or udp depending on
 * the -t option.  udpserver() and tcpserver() both finish with exits().
 */
void
server(int argc, char **argv, int myport, Progmap *progmap)
{
	Progmap *pg;

	fmtinstall('I', Iconv);
	fmtinstall('F', fcallfmt);
	fmtinstall('D', dirfmt);

	/* the original process exits at once; the child carries on as the server */
	switch(rfork(RFNOWAIT|RFENVG|RFNAMEG|RFNOTEG|RFFDG|RFPROC)){
	case -1:
		panic("fork");
	default:
		_exits(0);
	case 0:
		break;
	}

	/*
	 * timer process: created with RFMEM so that its store to
	 * alarmflag is seen by the server process (see rfork(2)).
	 * the child never leaves the loop; the parent falls out of
	 * the switch because neither case matches its return value.
	 */
	switch(rfork(RFMEM|RFPROC)){
	case 0:
		for(;;){
			sleep(30*1000);
			alarmflag = 1;
		}
	case -1:
		sysfatal("rfork: %r");
	}

	/* the table ends at the first entry whose init is nil */
	for(pg=progmap; pg->init; pg++)
		(*pg->init)(argc, argv);
	if(tcp)
		tcpserver(myport, progmap);
	else
		udpserver(myport, progmap);
}

/*
 * announce udp port myport, switch the connection to header mode,
 * then serve messages from its data file until servemsg() reports
 * failure.  does not return.
 */
static void
udpserver(int myport, Progmap *progmap)
{
	char addr[128], datafile[128], netdir[40];
	int cfd, dfd;

	snprint(addr, sizeof addr, "udp!*!%d", myport);
	cfd = announce(addr, netdir);
	if(cfd < 0)
		panic("can't announce %s: %r\n", addr);
	if(fprint(cfd, "headers") < 0)
		panic("can't set header mode: %r\n");

	snprint(datafile, sizeof datafile, "%s/data", netdir);
	dfd = open(datafile, ORDWR);
	if(dfd < 0)
		panic("can't open udp data: %r\n");
	close(cfd);

	chatsrv(0);
	clog("%s: listening to port %d\n", argv0, myport);
	for(;;)
		if(servemsg(dfd, read, write, myport, progmap) < 0)
			break;
	exits(0);
}

/*
 * announce tcp port myport and fork one process per incoming call.
 * each child accepts its call and serves messages from it with
 * readtcp/writetcp until servemsg() reports failure.  does not return.
 */
static void
tcpserver(int myport, Progmap *progmap)
{
	char adir[40];
	char ldir[40];
	char ds[40];
	int actl, lctl, data;

	snprint(ds, sizeof ds, "tcp!*!%d", myport);
	chatsrv(0);
	actl = -1;
	for(;;){
		/* (re)announce whenever there is no announce fd */
		if(actl < 0){
			actl = announce(ds, adir);
			if(actl < 0){
				/*
				 * NOTE(review): the "listening" message is only
				 * logged here, on announce failure; udpserver()
				 * logs it on success -- confirm this is intended.
				 */
				clog("%s: listening to tcp port %d\n",
					argv0, myport);
				clog("announcing: %r");
				break;
			}
		}
		lctl = listen(adir, ldir);
		if(lctl < 0){
			/* drop the announcement so the next pass announces afresh */
			close(actl);
			actl = -1;
			continue;
		}
		switch(fork()){
		case -1:
			clog("%s!%d: %r\n", argv0, myport);
			/* fall through */
		default:
			/* parent (or failed fork): give up this call, wait for the next */
			close(lctl);
			continue;
		case 0:
			/* child: serve this one connection, then exit */
			close(actl);
			data = accept(lctl, ldir);
			close(lctl);
			if(data < 0)
				exits(0);

			/* pretend it's udp; fill in Udphdr */
			getendpoints((Udphdr*)buf, ldir);

			while (servemsg(data, readtcp, writetcp, myport,
			    progmap) >= 0)
				continue;
			close(data);
			exits(0);
		}
	}
	exits(0);
}

/*
 * read one message from fd with readmsg into buf, decode it, and if
 * it is a CALL, answer it with writemsg: either by replaying a cached
 * reply or by dispatching to the matching procedure in progmap.
 * myport is used only in log messages.
 *
 * returns -1 on EOF or on a read error other than "interrupted",
 * 0 in every other case; the callers loop while the result is >= 0.
 */
static int
servemsg(int fd, long (*readmsg)(int, void*, long), long (*writemsg)(int, void*, long),
		int myport, Progmap * progmap)
{
	int i, n, nreply;
	Rpccall rcall, rreply;
	int vlo, vhi;
	Progmap *pg;
	Procmap *pp;
	char errbuf[ERRMAX];

	/* run the periodic hook if the timer process has fired since the last message */
	if(alarmflag){
		alarmflag = 0;
		if(rpcalarm)
			(*rpcalarm)();
	}
	n = (*readmsg)(fd, buf, sizeof buf);
	if(n < 0){
		errstr(errbuf, sizeof errbuf);
		/* an interrupted read is not fatal; the caller just tries again */
		if(strcmp(errbuf, "interrupted") == 0)
			return 0;
		clog("port %d: error: %s\n", myport, errbuf);
		return -1;
	}
	if(n == 0){
		clog("port %d: EOF\n", myport);
		return -1;
	}
	/* presumably these bytes are the remote address and port in the Udphdr -- TODO confirm offsets */
	if(rpcdebug == 1)
		fprint(2, "%s: rpc from %d.%d.%d.%d/%d\n",
			argv0, buf[12], buf[13], buf[14], buf[15],
			(buf[32]<<8)|buf[33]);
	i = rpcM2S(buf, &rcall, n);
	if(i != 0){
		/* undecodable message: log it and drop it without replying */
		clog("udp port %d: message format error %d\n",
			myport, i);
		return 0;
	}
	if(rpcdebug > 1)
		rpcprint(2, &rcall);
	/* only calls are served; anything else is ignored */
	if(rcall.mtype != CALL)
		return 0;
	/* a call already answered (same host, port, xid) gets the cached reply again */
	if(replycache(fd, &rcall, writemsg))
		return 0;
	nreply = 0;
	/* the reply carries the call's endpoints and transaction id */
	rreply.host = rcall.host;
	rreply.port = rcall.port;
	rreply.lhost = rcall.lhost;
	rreply.lport = rcall.lport;
	rreply.xid = rcall.xid;
	rreply.mtype = REPLY;
	/* only rpc version 2 is accepted */
	if(rcall.rpcvers != 2){
		rreply.stat = MSG_DENIED;
		rreply.rstat = RPC_MISMATCH;
		rreply.rlow = 2;
		rreply.rhigh = 2;
		goto send_reply;
	}
	if(rejectall){
		rreply.stat = MSG_DENIED;
		rreply.rstat = AUTH_ERROR;
		rreply.authstat = AUTH_TOOWEAK;
		goto send_reply;
	}
	/* bytes of the message remaining from rcall.args to the end */
	i = n - (((uchar *)rcall.args) - buf);
	if(rpcdebug > 1)
		fprint(2, "arg size = %d\n", i);
	rreply.stat = MSG_ACCEPTED;
	rreply.averf.flavor = 0;
	rreply.averf.count = 0;
	rreply.results = resultbuf;
	/*
	 * find the entry for this program and version; while looking,
	 * remember the range of other versions seen for the program
	 * so that a mismatch can report it.
	 */
	vlo = 0x7fffffff;
	vhi = -1;
	for(pg=progmap; pg->pmap; pg++){
		if(pg->progno != rcall.prog)
			continue;
		if(pg->vers == rcall.vers)
			break;
		if(pg->vers < vlo)
			vlo = pg->vers;
		if(pg->vers > vhi)
			vhi = pg->vers;
	}
	if(pg->pmap == 0){
		/* vhi is still -1 only if the program number never matched */
		if(vhi < 0)
			rreply.astat = PROG_UNAVAIL;
		else{
			rreply.astat = PROG_MISMATCH;
			rreply.plow = vlo;
			rreply.phigh = vhi;
		}
		goto send_reply;
	}
	for(pp = pg->pmap; pp->procp; pp++)
		if(rcall.proc == pp->procno){
			if(rpcdebug > 1)
				fprint(2, "process %d\n", pp->procno);
			rreply.astat = SUCCESS;
			/* the procedure's return value is passed to rpcS2M; a negative value suppresses the reply */
			nreply = (*pp->procp)(i, &rcall, &rreply);
			goto send_reply;
		}
	rreply.astat = PROC_UNAVAIL;
send_reply:
	if(nreply >= 0){
		i = rpcS2M(&rreply, nreply, rbuf);
		if(rpcdebug > 1)
			rpcprint(2, &rreply);
		(*writemsg)(fd, rbuf, i);
		/* remember the marshalled reply in case the call is retransmitted */
		cachereply(&rreply, rbuf, i);
	}
	return 0;
}

/*
 * read the file dir/file, expected to hold "address!port", and store
 * the address (via parseip) in addr and the port (via hnputs) in port.
 * if the file cannot be read the address and port strings are both
 * taken to be "unknown"; if it holds no '!' only the port is.
 *
 * the strings handed to parseip and atoi point into the local buffer
 * or at a constant, so nothing is allocated and nothing can leak.
 */
static void
getendpoint(char *dir, char *file, uchar *addr, uchar *port)
{
	int fd, n;
	char buf[128];
	char *sys, *serv;

	sys = serv = nil;

	snprint(buf, sizeof buf, "%s/%s", dir, file);
	fd = open(buf, OREAD);
	if(fd >= 0){
		n = read(fd, buf, sizeof(buf)-1);
		if(n > 0){
			/* the last byte read (presumably a newline) is replaced by the terminator */
			buf[n-1] = 0;
			sys = buf;
			serv = strchr(buf, '!');
			if(serv != nil)
				*serv++ = 0;	/* split "address!port" in place */
		}
		close(fd);
	}
	if(serv == nil)
		serv = "unknown";
	if(sys == nil)
		sys = "unknown";
	parseip(addr, sys);
	hnputs(port, atoi(serv));
}

/*
 * set Udphdr values from protocol dir local & remote files:
 * dir/local supplies laddr and lport, dir/remote supplies raddr and rport.
 * tcpserver() uses this so a tcp call looks like a udp message to servemsg().
 */
static void
getendpoints(Udphdr *ep, char *dir)
{
	getendpoint(dir, "local", ep->laddr, ep->lport);
	getendpoint(dir, "remote", ep->raddr, ep->rport);
}

/*
 * read one complete record from a tcp connection into vbuf, leaving
 * room for a Udphdr in front of the data.  a record arrives as one or
 * more fragments, each led by a 4-byte big-endian mark: the top bit
 * flags the last fragment, the low 31 bits give the fragment length.
 *
 * returns the record length plus Udphdrsize, 0 if a mark or a
 * fragment could not be read in full, -1 if a fragment does not fit
 * in the space left in the buffer.
 */
static long
readtcp(int fd, void *vbuf, long blen)
{
	uchar mk[4];
	int fraglen, total;
	ulong mark, last;
	char *dst;

	dst = (char*)vbuf + Udphdrsize;
	blen -= Udphdrsize;

	total = 0;
	do{
		if(readn(fd, mk, 4) < 4)
			return 0;
		mark = (mk[0]<<24)|(mk[1]<<16)|(mk[2]<<8)|mk[3];
		fraglen = mark & 0x7fffffff;
		last = mark & 0x80000000;
		if(fraglen > blen-total)
			return -1;
		if(readn(fd, dst+total, fraglen) != fraglen)
			return 0;
		total += fraglen;
	}while(!last);
	return total + Udphdrsize;
}

/*
 * send the data that follows the Udphdr in vbuf as one tcp record
 * made of a single, final fragment.  the 4-byte record mark is built
 * in place over the last 4 bytes of the header area, so those bytes
 * of vbuf are overwritten.  returns what write() returns.
 */
static long
writetcp(int fd, void *vbuf, long len)
{
	char *mark;
	long paylen;

	paylen = len - Udphdrsize;
	mark = (char*)vbuf + Udphdrsize - 4;

	/* big-endian length with the last-fragment bit set */
	mark[0] = 0x80 | (paylen>>24);
	mark[1] = paylen>>16;
	mark[2] = paylen>>8;
	mark[3] = paylen;
	return write(fd, mark, paylen + 4);
}
/*
 *long
 *niwrite(int fd, void *buf, long count)
 *{
 *	char errbuf[ERRLEN];
 *	long n;
 *
 *	for(;;){
 *		n = write(fd, buf, count);
 *		if(n < 0){
 *			errstr(errbuf);
 *			if(strcmp(errbuf, "interrupted") == 0)
 *				continue;
 *			clog("niwrite error: %s\n", errbuf);
 *			werrstr(errbuf);
 *		}
 *		break;
 *	}
 *	return n;
 *}
 */
/*
 * write n bytes from buf to fd and return write()'s result.
 * the save/restore of the pending alarm that once surrounded
 * the write is disabled, so this is a plain write.
 */
long
niwrite(int fd, void *buf, long n)
{
	return write(fd, buf, n);
}

typedef struct Namecache	Namecache;
/* one entry of the address-to-domain-name cache */
struct Namecache {
	char dom[256];	/* NUL-terminated domain name */
	ulong ipaddr;	/* address the name was stored for */
	Namecache *next;	/* next entry on the dnscache list */
};

/* head of the cache list; a successful lookup moves its entry to the front */
Namecache *dnscache;

/*
 * find the cache entry whose domain name is exactly the first len
 * bytes of name.  a hit is moved to the front of dnscache and
 * returned; nil means no entry matched or len cannot fit in dom.
 */
static Namecache*
domlookupl(void *name, int len)
{
	Namecache *ent, **link;

	if(len >= sizeof(ent->dom))
		return nil;

	link = &dnscache;
	while((ent = *link) != nil){
		if(strncmp(ent->dom, name, len) == 0 && ent->dom[len] == 0){
			/* unlink, then push on the front */
			*link = ent->next;
			ent->next = dnscache;
			dnscache = ent;
			return ent;
		}
		link = &ent->next;
	}
	return nil;
}

/*
 * find the cache entry stored for address ip.  a hit is moved to
 * the front of dnscache and returned; nil means there is none.
 */
static Namecache*
iplookup(ulong ip)
{
	Namecache *ent, **link;

	link = &dnscache;
	while((ent = *link) != nil){
		if(ent->ipaddr == ip){
			/* unlink, then push on the front */
			*link = ent->next;
			ent->next = dnscache;
			dnscache = ent;
			return ent;
		}
		link = &ent->next;
	}
	return nil;
}

/*
 * add a cache entry mapping ip to the first len bytes of name and
 * put it at the front of dnscache.
 *
 * returns the new entry, or nil if the name does not fit in dom or
 * memory could not be allocated.  (it used to return nil on success
 * too, which made the result useless to a caller.)
 */
static Namecache*
addcacheentry(void *name, int len, ulong ip)
{
	Namecache *n;

	/* a negative len converts to a huge unsigned value and is rejected here as well */
	if(len >= sizeof(n->dom))
		return nil;

	n = malloc(sizeof(*n));
	if(n == nil)
		return nil;
	strncpy(n->dom, name, len);
	n->dom[len] = 0;
	n->ipaddr = ip;
	n->next = dnscache;
	dnscache = n;
	return n;
}

/*
 * store the domain name for address ip in name, a buffer of len
 * bytes; the result is always NUL-terminated and is truncated if it
 * does not fit.  the local cache is tried first; on a miss the "dom"
 * attribute is looked up with csgetvalue() and the answer is cached.
 *
 * returns 0 on success, -1 if len is not positive or the lookup
 * finds nothing.
 */
int
getdnsdom(ulong ip, char *name, int len)
{
	char buf[128];
	Namecache *nc;
	char *p;

	/* without room for at least the terminator nothing can be stored */
	if(len <= 0)
		return -1;

	nc = iplookup(ip);
	if(nc != nil){
		strncpy(name, nc->dom, len);
		name[len-1] = 0;
		return 0;
	}
	clog("getdnsdom: %I\n", ip);
	snprint(buf, sizeof buf, "%I", ip);
	p = csgetvalue("/net", "ip", buf, "dom", nil);
	if(p == nil)
		return -1;
	/* terminate inside the buffer: the last valid index is len-1, not len */
	strncpy(name, p, len);
	name[len-1] = 0;
	free(p);
	addcacheentry(name, strlen(name), ip);
	return 0;
}

/*
 * store the domain name for address ip in dom (len bytes), with the
 * first matching entry of the prefix table removed from its front.
 * returns 0 on success, -1 if getdnsdom() fails.
 *
 * NOTE(review): the first table entry is the empty string, which
 * matches every name (a zero-length strncmp compares equal), so the
 * loop always stops there and "gate-", "fddi-" and "u-" are never
 * stripped -- confirm whether that is intended.
 */
int
getdom(ulong ip, char *dom, int len)
{
	static char *prefix[] = { "", "gate-", "fddi-", "u-", 0 };
	int k, plen;

	if(getdnsdom(ip, dom, len) < 0)
		return -1;

	for(k = 0; prefix[k] != 0; k++){
		plen = strlen(prefix[k]);
		if(strncmp(dom, prefix[k], plen) != 0)
			continue;
		/* slide the rest of the buffer down over the prefix */
		memmove(dom, dom+plen, len-plen);
		break;
	}
	return 0;
}

#define	MAXCACHE	64	/* cachereply() drops the tail entry once this many replies are held */

static Rpccache *head, *tail;	/* doubly linked reply list: new and re-used entries go to head, tail is dropped first */
static int	ncache;	/* number of entries currently on the list */

/*
 * remember the marshalled reply buf[0..len) for the call described
 * by rp (host, port, xid) at the head of the reply cache, first
 * dropping the tail entry if the cache already holds MAXCACHE
 * replies.  does nothing when caching is disabled (nocache).
 */
static void
cachereply(Rpccall *rp, void *buf, int len)
{
	Rpccache *ent;

	if(nocache)
		return;

	/* make room by discarding the entry at the tail */
	if(ncache >= MAXCACHE){
		if(rpcdebug)
			fprint(2, "%s: drop  %I/%ld, xid %uld, len %d\n",
				argv0, tail->host,
				tail->port, tail->xid, tail->n);
		tail = tail->prev;
		free(tail->next);
		tail->next = 0;
		--ncache;
	}

	/* the entry is sized to hold len bytes of data */
	ent = malloc(sizeof(Rpccache)+len-4);
	if(ent == 0){
		clog("cachereply: malloc %d failed\n", len);
		return;
	}

	/* fill in the entry */
	ent->host = rp->host;
	ent->port = rp->port;
	ent->xid = rp->xid;
	ent->n = len;
	memmove(ent->data, buf, len);

	/* link it in at the head */
	ent->prev = 0;
	ent->next = head;
	if(head == 0)
		tail = ent;
	else
		head->prev = ent;
	head = ent;
	++ncache;

	if(rpcdebug)
		fprint(2, "%s: cache %I/%ld, xid %uld, len %d\n",
			argv0, ent->host, ent->port, ent->xid, ent->n);
}

/*
 * if a reply to the call rp (same host, port and xid) is cached,
 * move it to the head of the cache, send it again with writemsg and
 * return 1; otherwise return 0 and send nothing.
 */
static int
replycache(int fd, Rpccall *rp, long (*writemsg)(int, void*, long))
{
	Rpccache *hit;

	hit = head;
	while(hit != 0){
		if(hit->host == rp->host && hit->port == rp->port && hit->xid == rp->xid)
			break;
		hit = hit->next;
	}
	if(hit == 0)
		return 0;

	/* an entry with no predecessor is already at the head */
	if(hit->prev != 0){
		/* unlink */
		hit->prev->next = hit->next;
		if(hit->next == 0)
			tail = hit->prev;
		else
			hit->next->prev = hit->prev;
		/* relink at the front */
		hit->prev = 0;
		hit->next = head;
		head->prev = hit;
		head = hit;
	}

	(*writemsg)(fd, hit->data, hit->n);
	if(rpcdebug)
		fprint(2, "%s: reply %I/%ld, xid %uld, len %d\n",
			argv0, hit->host, hit->port, hit->xid, hit->n);
	return 1;
}

/*
 * print verb 'I' as installed by server(): formats a ulong as a
 * dotted quad, most significant byte first.
 */
static int
Iconv(Fmt *f)
{
	char dotted[16];	/* "255.255.255.255" plus the terminator */
	ulong ip, a, b, c, d;

	ip = va_arg(f->args, ulong);
	a = (ip>>24)&0xff;
	b = (ip>>16)&0xff;
	c = (ip>>8)&0xff;
	d = ip&0xff;
	snprint(dotted, sizeof dotted, "%ld.%ld.%ld.%ld", a, b, c, d);
	return fmtstrcpy(f, dotted);
}