shithub: riscv

ref: 0e516cbf488a9ec8edae3dc6efca2baa0c3d95bd
dir: /sys/src/cmd/ssh/sshnet.c/

View raw version
/*
 * SSH network file system.
 * Presents remote TCP stack as /net-style file system.
 */

#include "ssh.h"
#include <bio.h>
#include <ndb.h>
#include <thread.h>
#include <fcall.h>
#include <9p.h>

int rawhack = 1;
Conn *conn;
char *remoteip	= "<remote>";
char *mtpt;

Cipher *allcipher[] = {
	&cipherrc4,
	&cipherblowfish,
	&cipher3des,
	&cipherdes,
	&ciphernone,
	&ciphertwiddle,
};

Auth *allauth[] = {
	&authpassword,
	&authrsa,
	&authtis,
};

char *cipherlist = "rc4 3des";
char *authlist = "rsa password tis";

Cipher*
findcipher(char *name, Cipher **list, int nlist)
{
	int i;

	for(i=0; i<nlist; i++)
		if(strcmp(name, list[i]->name) == 0)
			return list[i];
	error("unknown cipher %s", name);
	return nil;
}

Auth*
findauth(char *name, Auth **list, int nlist)
{
	int i;

	for(i=0; i<nlist; i++)
		if(strcmp(name, list[i]->name) == 0)
			return list[i];
	error("unknown auth %s", name);
	return nil;
}

void
usage(void)
{
	fprint(2, "usage: sshnet [-A authlist] [-c cipherlist] [-m mtpt] [user@]hostname\n");
	exits("usage");
}

int
isatty(int fd)
{
	char buf[64];

	buf[0] = '\0';
	fd2path(fd, buf, sizeof buf);
	if(strlen(buf)>=9 && strcmp(buf+strlen(buf)-9, "/dev/cons")==0)
		return 1;
	return 0;
}

enum
{
	Qroot,
	Qcs,
	Qtcp,
	Qclone,
	Qn,
	Qctl,
	Qdata,
	Qlocal,
	Qremote,
	Qstatus,
};

#define PATH(type, n)		((type)|((n)<<8))
#define TYPE(path)			((int)(path) & 0xFF)
#define NUM(path)			((uint)(path)>>8)

Channel *sshmsgchan;		/* chan(Msg*) */
Channel *fsreqchan;			/* chan(Req*) */
Channel *fsreqwaitchan;		/* chan(nil) */
Channel *fsclunkchan;		/* chan(Fid*) */
Channel *fsclunkwaitchan;	/* chan(nil) */
ulong time0;

enum
{
	Closed,
	Dialing,
	Established,
	Teardown,
};

char *statestr[] = {
	"Closed",
	"Dialing",
	"Established",
	"Teardown",
};

typedef struct Client Client;
struct Client
{
	int ref;
	int state;
	int num;
	int servernum;
	char *connect;
	Req *rq;
	Req **erq;
	Msg *mq;
	Msg **emq;
};

int nclient;
Client **client;

int
newclient(void)
{
	int i;
	Client *c;

	for(i=0; i<nclient; i++)
		if(client[i]->ref==0 && client[i]->state == Closed)
			return i;

	if(nclient%16 == 0)
		client = erealloc9p(client, (nclient+16)*sizeof(client[0]));

	c = emalloc9p(sizeof(Client));
	memset(c, 0, sizeof(*c));
	c->num = nclient;
	client[nclient++] = c;
	return c->num;
}

void
queuereq(Client *c, Req *r)
{
	if(c->rq==nil)
		c->erq = &c->rq;
	*c->erq = r;
	r->aux = nil;
	c->erq = (Req**)&r->aux;
}

void
queuemsg(Client *c, Msg *m)
{
	if(c->mq==nil)
		c->emq = &c->mq;
	*c->emq = m;
	m->link = nil;
	c->emq = (Msg**)&m->link;
}

void
matchmsgs(Client *c)
{
	Req *r;
	Msg *m;
	int n, rm;

	while(c->rq && c->mq){
		r = c->rq;
		c->rq = r->aux;

		rm = 0;
		m = c->mq;
		n = r->ifcall.count;
		if(n >= m->ep - m->rp){
			n = m->ep - m->rp;
			c->mq = m->link;
			rm = 1;
		}
		memmove(r->ofcall.data, m->rp, n);
		if(rm)
			free(m);
		else
			m->rp += n;
		r->ofcall.count = n;
		respond(r, nil);
	}
}

Req*
findreq(Client *c, Req *r)
{
	Req **l;

	for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
		if(*l == r){
			*l = r->aux;
			if(*l == nil)
				c->erq = l;
			return r;
		}
	}
	return nil;
}

void
dialedclient(Client *c)
{
	Req *r;

	if(r=c->rq){
		if(r->aux != nil)
			sysfatal("more than one outstanding dial request (BUG)");
		if(c->state == Established)
			respond(r, nil);
		else
			respond(r, "connect failed");
	}
	c->rq = nil;
}

void
teardownclient(Client *c)
{
	Msg *m;

	c->state = Teardown;
	m = allocmsg(conn, SSH_MSG_CHANNEL_INPUT_EOF, 4);
	putlong(m, c->servernum);
	sendmsg(m);
}

void
hangupclient(Client *c)
{
	Req *r, *next;
	Msg *m, *mnext;

	c->state = Closed;
	for(m=c->mq; m; m=mnext){
		mnext = m->link;
		free(m);
	}
	c->mq = nil;
	for(r=c->rq; r; r=next){
		next = r->aux;
		respond(r, "hangup on network connection");
	}
	c->rq = nil;
}

void
closeclient(Client *c)
{
	Msg *m, *next;

	if(--c->ref)
		return;

	if(c->rq != nil)
		sysfatal("ref count reached zero with requests pending (BUG)");

	for(m=c->mq; m; m=next){
		next = m->link;
		free(m);
	}
	c->mq = nil;

	if(c->state != Closed)
		teardownclient(c);
}

	
void
sshreadproc(void *a)
{
	Conn *c;
	Msg *m;

	c = a;
	for(;;){
		m = recvmsg(c, -1);
		if(m == nil)
			sysfatal("eof on ssh connection");
		sendp(sshmsgchan, m);
	}
}

typedef struct Tab Tab;
struct Tab
{
	char *name;
	ulong mode;
};

Tab tab[] =
{
	"/",		DMDIR|0555,
	"cs",		0666,
	"tcp",	DMDIR|0555,	
	"clone",	0666,
	nil,		DMDIR|0555,
	"ctl",		0666,
	"data",	0666,
	"local",	0444,
	"remote",	0444,
	"status",	0444,
};

static void
fillstat(Dir *d, uvlong path)
{
	Tab *t;

	memset(d, 0, sizeof(*d));
	d->uid = estrdup9p("ssh");
	d->gid = estrdup9p("ssh");
	d->qid.path = path;
	d->atime = d->mtime = time0;
	t = &tab[TYPE(path)];
	if(t->name)
		d->name = estrdup9p(t->name);
	else{
		d->name = smprint("%ud", NUM(path));
		if(d->name == nil)
			sysfatal("out of memory");
	}
	d->qid.type = t->mode>>24;
	d->mode = t->mode;
}

static void
fsattach(Req *r)
{
	if(r->ifcall.aname && r->ifcall.aname[0]){
		respond(r, "invalid attach specifier");
		return;
	}
	r->fid->qid.path = PATH(Qroot, 0);
	r->fid->qid.type = QTDIR;
	r->fid->qid.vers = 0;
	r->ofcall.qid = r->fid->qid;
	respond(r, nil);
}

static void
fsstat(Req *r)
{
	fillstat(&r->d, r->fid->qid.path);
	respond(r, nil);
}

static int
rootgen(int i, Dir *d, void*)
{
	i += Qroot+1;
	if(i <= Qtcp){
		fillstat(d, i);
		return 0;
	}
	return -1;
}

static int
tcpgen(int i, Dir *d, void*)
{
	i += Qtcp+1;
	if(i < Qn){
		fillstat(d, i);
		return 0;
	}
	i -= Qn;
	if(i < nclient){
		fillstat(d, PATH(Qn, i));
		return 0;
	}
	return -1;
}

static int
clientgen(int i, Dir *d, void *aux)
{
	Client *c;

	c = aux;
	i += Qn+1;
	if(i <= Qstatus){
		fillstat(d, PATH(i, c->num));
		return 0;
	}
	return -1;
}

static char*
fswalk1(Fid *fid, char *name, Qid *qid)
{
	int i, n;
	char buf[32];
	ulong path;

	path = fid->qid.path;
	if(!(fid->qid.type&QTDIR))
		return "walk in non-directory";

	if(strcmp(name, "..") == 0){
		switch(TYPE(path)){
		case Qn:
			qid->path = PATH(Qtcp, NUM(path));
			qid->type = tab[Qtcp].mode>>24;
			return nil;
		case Qtcp:
			qid->path = PATH(Qroot, 0);
			qid->type = tab[Qroot].mode>>24;
			return nil;
		case Qroot:
			return nil;
		default:
			return "bug in fswalk1";
		}
	}

	i = TYPE(path)+1;
	for(; i<nelem(tab); i++){
		if(i==Qn){
			n = atoi(name);
			snprint(buf, sizeof buf, "%d", n);
			if(n < nclient && strcmp(buf, name) == 0){
				qid->path = PATH(i, n);
				qid->type = tab[i].mode>>24;
				return nil;
			}
			break;
		}
		if(strcmp(name, tab[i].name) == 0){
			qid->path = PATH(i, NUM(path));
			qid->type = tab[i].mode>>24;
			return nil;
		}
		if(tab[i].mode&DMDIR)
			break;
	}
	return "directory entry not found";
}

typedef struct Cs Cs;
struct Cs
{
	char *resp;
	int isnew;
};

static int
ndbfindport(char *p)
{
	char *s, *port;
	int n;
	static Ndb *db;

	if(*p == '\0')
		return -1;

	n = strtol(p, &s, 0);
	if(*s == '\0')
		return n;

	if(db == nil){
		db = ndbopen("/lib/ndb/common");
		if(db == nil)
			return -1;
	}

	port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
	if(port == nil)
		return -1;
	n = atoi(port);
	free(port);

	return n;
}	

static void
csread(Req *r)
{
	Cs *cs;

	cs = r->fid->aux;
	if(cs->resp==nil){
		respond(r, "cs read without write");
		return;
	}
	if(r->ifcall.offset==0){
		if(!cs->isnew){
			r->ofcall.count = 0;
			respond(r, nil);
			return;
		}
		cs->isnew = 0;
	}
	readstr(r, cs->resp);
	respond(r, nil);
}

static void
cswrite(Req *r)
{
	int port, nf;
	char err[ERRMAX], *f[4], *s, *ns;
	Cs *cs;

	cs = r->fid->aux;
	s = emalloc(r->ifcall.count+1);
	memmove(s, r->ifcall.data, r->ifcall.count);
	s[r->ifcall.count] = '\0';

	nf = getfields(s, f, nelem(f), 0, "!");
	if(nf != 3){
		free(s);
		respond(r, "can't translate");
		return;
	}
	if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
		free(s);
		respond(r, "unknown protocol");
		return;
	}
	port = ndbfindport(f[2]);
	if(port <= 0){
		free(s);
		respond(r, "no translation found");
		return;
	}

	ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
	if(ns == nil){
		free(s);
		rerrstr(err, sizeof err);
		respond(r, err);
		return;
	}
	free(s);
	free(cs->resp);
	cs->resp = ns;
	cs->isnew = 1;
	r->ofcall.count = r->ifcall.count;
	respond(r, nil);
}

static void
ctlread(Req *r, Client *c)
{
	char buf[32];

	sprint(buf, "%d", c->num);
	readstr(r, buf);
	respond(r, nil);
}

static void
ctlwrite(Req *r, Client *c)
{
	char *f[3], *s;
	int nf;
	Msg *m;

	s = emalloc(r->ifcall.count+1);
	memmove(s, r->ifcall.data, r->ifcall.count);
	s[r->ifcall.count] = '\0';

	nf = tokenize(s, f, 3);
	if(nf == 0){
		free(s);
		respond(r, nil);
		return;
	}

	if(strcmp(f[0], "hangup") == 0){
		if(c->state != Established)
			goto Badarg;
		if(nf != 1)
			goto Badarg;
		queuereq(c, r);
		teardownclient(c);
	}else if(strcmp(f[0], "connect") == 0){
		if(c->state != Closed)
			goto Badarg;
		if(nf != 2)
			goto Badarg;
		c->connect = estrdup9p(f[1]);
		nf = getfields(f[1], f, nelem(f), 0, "!");
		if(nf != 2){
			free(c->connect);
			c->connect = nil;
			goto Badarg;
		}
		c->state = Dialing;
		m = allocmsg(conn, SSH_MSG_PORT_OPEN, 4+4+strlen(f[0])+4+4+strlen("localhost"));
		putlong(m, c->num);
		putstring(m, f[0]);
		putlong(m, ndbfindport(f[1]));
		putstring(m, "localhost");
		queuereq(c, r);
		sendmsg(m);
	}else{
	Badarg:
		respond(r, "bad or inappropriate tcp control message");
	}
	free(s);
}

static void
dataread(Req *r, Client *c)
{
	if(c->state != Established){
		respond(r, "not connected");
		return;
	}
	queuereq(c, r);
	matchmsgs(c);
}

static void
datawrite(Req *r, Client *c)
{
	Msg *m;

	if(c->state != Established){
		respond(r, "not connected");
		return;
	}
	if(r->ifcall.count){
		m = allocmsg(conn, SSH_MSG_CHANNEL_DATA, 4+4+r->ifcall.count);
		putlong(m, c->servernum);
		putlong(m, r->ifcall.count);
		putbytes(m, r->ifcall.data, r->ifcall.count);
		sendmsg(m);
	}
	r->ofcall.count = r->ifcall.count;
	respond(r, nil);
}

static void
localread(Req *r)
{
	char buf[128];

	snprint(buf, sizeof buf, "%s!%d\n", remoteip, 0);
	readstr(r, buf);
	respond(r, nil);
}

static void
remoteread(Req *r, Client *c)
{
	char *s;
	char buf[128];

	s = c->connect;
	if(s == nil)
		s = "::!0";
	snprint(buf, sizeof buf, "%s\n", s);
	readstr(r, buf);
	respond(r, nil);
}

static void
statusread(Req *r, Client *c)
{
	char buf[64];
	char *s;

	snprint(buf, sizeof buf, "%s!%d", remoteip, 0);
	s = statestr[c->state];
	readstr(r, s);
	respond(r, nil);
}

static void
fsread(Req *r)
{
	char e[ERRMAX];
	ulong path;

	path = r->fid->qid.path;
	switch(TYPE(path)){
	default:
		snprint(e, sizeof e, "bug in fsread path=%lux", path);
		respond(r, e);
		break;

	case Qroot:
		dirread9p(r, rootgen, nil);
		respond(r, nil);
		break;

	case Qcs:
		csread(r);
		break;

	case Qtcp:
		dirread9p(r, tcpgen, nil);
		respond(r, nil);
		break;

	case Qn:
		dirread9p(r, clientgen, client[NUM(path)]);
		respond(r, nil);
		break;

	case Qctl:
		ctlread(r, client[NUM(path)]);
		break;

	case Qdata:
		dataread(r, client[NUM(path)]);
		break;

	case Qlocal:
		localread(r);
		break;

	case Qremote:
		remoteread(r, client[NUM(path)]);
		break;

	case Qstatus:
		statusread(r, client[NUM(path)]);
		break;
	}
}

static void
fswrite(Req *r)
{
	ulong path;
	char e[ERRMAX];

	path = r->fid->qid.path;
	switch(TYPE(path)){
	default:
		snprint(e, sizeof e, "bug in fswrite path=%lux", path);
		respond(r, e);
		break;

	case Qcs:
		cswrite(r);
		break;

	case Qctl:
		ctlwrite(r, client[NUM(path)]);
		break;

	case Qdata:
		datawrite(r, client[NUM(path)]);
		break;
	}
}

static void
fsopen(Req *r)
{
	static int need[4] = { 4, 2, 6, 1 };
	ulong path;
	int n;
	Tab *t;
	Cs *cs;

	/*
	 * lib9p already handles the blatantly obvious.
	 * we just have to enforce the permissions we have set.
	 */
	path = r->fid->qid.path;
	t = &tab[TYPE(path)];
	n = need[r->ifcall.mode&3];
	if((n&t->mode) != n){
		respond(r, "permission denied");
		return;
	}

	switch(TYPE(path)){
	case Qcs:
		cs = emalloc(sizeof(Cs));
		r->fid->aux = cs;
		respond(r, nil);
		break;
	case Qclone:
		n = newclient();
		path = PATH(Qctl, n);
		r->fid->qid.path = path;
		r->ofcall.qid.path = path;
		if(chatty9p)
			fprint(2, "open clone => path=%lux\n", path);
		t = &tab[Qctl];
		/* fall through */
	default:
		if(t-tab >= Qn)
			client[NUM(path)]->ref++;
		respond(r, nil);
		break;
	}
}

static void
fsflush(Req *r)
{
	int i;

	for(i=0; i<nclient; i++)
		if(findreq(client[i], r->oldreq))
			respond(r->oldreq, "interrupted");
	respond(r, nil);
}

static void
handlemsg(Msg *m)
{
	int chan, n;
	Client *c;

	switch(m->type){
	case SSH_MSG_DISCONNECT:
	case SSH_CMSG_EXIT_CONFIRMATION:
		sysfatal("disconnect");

	case SSH_CMSG_STDIN_DATA:
	case SSH_CMSG_EOF:
	case SSH_CMSG_WINDOW_SIZE:
		/* don't care */
		free(m);
		break;

	case SSH_MSG_CHANNEL_DATA:
		chan = getlong(m);
		n = getlong(m);
		if(m->rp+n != m->ep)
			sysfatal("got bad channel data");
		if(chan<nclient && (c=client[chan])->state==Established){
			queuemsg(c, m);
			matchmsgs(c);
		}else
			free(m);
		break;

	case SSH_MSG_CHANNEL_INPUT_EOF:
		chan = getlong(m);
		free(m);
		if(chan<nclient){
			c = client[chan];
			chan = c->servernum;
			hangupclient(c);
			m = allocmsg(conn, SSH_MSG_CHANNEL_OUTPUT_CLOSED, 4);
			putlong(m, chan);
			sendmsg(m);
		}
		break;

	case SSH_MSG_CHANNEL_OUTPUT_CLOSED:
		chan = getlong(m);
		if(chan<nclient)
			hangupclient(client[chan]);
		free(m);
		break;

	case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
		chan = getlong(m);
		c = nil;
		if(chan>=nclient || (c=client[chan])->state != Dialing){
			if(c)
				fprint(2, "cstate %d\n", c->state);
			sysfatal("got unexpected open confirmation for %d", chan);
		}
		c->servernum = getlong(m);
		c->state = Established;
		dialedclient(c);
		free(m);
		break;

	case SSH_MSG_CHANNEL_OPEN_FAILURE:
		chan = getlong(m);
		c = nil;
		if(chan>=nclient || (c=client[chan])->state != Dialing)
			sysfatal("got unexpected open failure");
		if(m->rp+4 <= m->ep)
			c->servernum = getlong(m);
		c->state = Closed;
		dialedclient(c);
		free(m);
		break;
	}
}

void
fsnetproc(void*)
{
	ulong path;
	Alt a[4];
	Cs *cs;
	Fid *fid;
	Req *r;
	Msg *m;

	threadsetname("fsthread");

	a[0].op = CHANRCV;
	a[0].c = fsclunkchan;
	a[0].v = &fid;
	a[1].op = CHANRCV;
	a[1].c = fsreqchan;
	a[1].v = &r;
	a[2].op = CHANRCV;
	a[2].c = sshmsgchan;
	a[2].v = &m;
	a[3].op = CHANEND;

	for(;;){
		switch(alt(a)){
		case 0:
			path = fid->qid.path;
			switch(TYPE(path)){
			case Qcs:
				cs = fid->aux;
				if(cs){
					free(cs->resp);
					free(cs);
				}
				break;
			}
			if(fid->omode != -1 && TYPE(path) >= Qn)
				closeclient(client[NUM(path)]);
			sendp(fsclunkwaitchan, nil);
			break;
		case 1:
			switch(r->ifcall.type){
			case Tattach:
				fsattach(r);
				break;
			case Topen:
				fsopen(r);
				break;
			case Tread:
				fsread(r);
				break;
			case Twrite:
				fswrite(r);
				break;
			case Tstat:
				fsstat(r);
				break;
			case Tflush:
				fsflush(r);
				break;
			default:
				respond(r, "bug in fsthread");
				break;
			}
			sendp(fsreqwaitchan, 0);
			break;
		case 2:
			handlemsg(m);
			break;
		}
	}
}

static void
fssend(Req *r)
{
	sendp(fsreqchan, r);
	recvp(fsreqwaitchan);	/* avoids need to deal with spurious flushes */
}

static void
fsdestroyfid(Fid *fid)
{
	sendp(fsclunkchan, fid);
	recvp(fsclunkwaitchan);
}

void
takedown(Srv*)
{
	threadexitsall("done");
}

Srv fs = 
{
.attach=		fssend,
.destroyfid=	fsdestroyfid,
.walk1=		fswalk1,
.open=		fssend,
.read=		fssend,
.write=		fssend,
.stat=		fssend,
.flush=		fssend,
.end=		takedown,
};

void
threadmain(int argc, char **argv)
{
	int i, fd;
	char *host, *user, *p, *service;
	char *f[16];
	Msg *m;
	static Conn c;

	fmtinstall('B', mpfmt);
	fmtinstall('H', encodefmt);

	mtpt = "/net";
	service = nil;
	user = nil;
	ARGBEGIN{
	case 'B':	/* undocumented, debugging */
		doabort = 1;
		break;
	case 'D':	/* undocumented, debugging */
		debuglevel = strtol(EARGF(usage()), nil, 0);
		break;
	case '9':	/* undocumented, debugging */
		chatty9p++;
		break;

	case 'A':
		authlist = EARGF(usage());
		break;
	case 'c':
		cipherlist = EARGF(usage());
		break;
	case 'm':
		mtpt = EARGF(usage());
		break;
	case 's':
		service = EARGF(usage());
		break;
	default:
		usage();
	}ARGEND

	if(argc != 1)
		usage();

	host = argv[0];

	if((p = strchr(host, '@')) != nil){
		*p++ = '\0';
		user = host;
		host = p;
	}
	if(user == nil)
		user = getenv("user");
	if(user == nil)
		sysfatal("cannot find user name");

	privatefactotum();

	if((fd = dial(netmkaddr(host, "tcp", "ssh"), nil, nil, nil)) < 0)
		sysfatal("dialing %s: %r", host);

	c.interactive = isatty(0);
	c.fd[0] = c.fd[1] = fd;
	c.user = user;
	c.host = host;
	setaliases(&c, host);

	c.nokcipher = getfields(cipherlist, f, nelem(f), 1, ", ");
	c.okcipher = emalloc(sizeof(Cipher*)*c.nokcipher);
	for(i=0; i<c.nokcipher; i++)
		c.okcipher[i] = findcipher(f[i], allcipher, nelem(allcipher));

	c.nokauth = getfields(authlist, f, nelem(f), 1, ", ");
	c.okauth = emalloc(sizeof(Auth*)*c.nokauth);
	for(i=0; i<c.nokauth; i++)
		c.okauth[i] = findauth(f[i], allauth, nelem(allauth));

	sshclienthandshake(&c);

	requestpty(&c);		/* turns on TCP_NODELAY on other side */
	m = allocmsg(&c, SSH_CMSG_EXEC_SHELL, 0);
	sendmsg(m);

	time0 = time(0);
	sshmsgchan = chancreate(sizeof(Msg*), 16);
	fsreqchan = chancreate(sizeof(Req*), 0);
	fsreqwaitchan = chancreate(sizeof(void*), 0);
	fsclunkchan = chancreate(sizeof(Fid*), 0);
	fsclunkwaitchan = chancreate(sizeof(void*), 0);

	conn = &c;
	procrfork(sshreadproc, &c, 8192, RFNAMEG|RFNOTEG);
	procrfork(fsnetproc, nil, 8192, RFNAMEG|RFNOTEG);

	threadpostmountsrv(&fs, service, mtpt, MREPL);
	exits(0);
}