ref: b86a12149ade500326a238753c31b6e0178d3b5b
dir: /sys/src/cmd/ssh/sshnet.c/
/*
 * SSH network file system.
 * Presents remote TCP stack as /net-style file system.
 */

#include "ssh.h"
#include <bio.h>
#include <ndb.h>
#include <thread.h>
#include <fcall.h>
#include <9p.h>

int rawhack = 1;	/* not used in this file; presumably read by the shared ssh code — confirm */
Conn *conn;		/* the single ssh connection all clients are multiplexed over */
char *remoteip = "<remote>";	/* address reported by the local file */
char *mtpt;		/* mount point; also used to build cs replies */

Cipher *allcipher[] = {
	&cipherrc4,
	&cipherblowfish,
	&cipher3des,
	&cipherdes,
	&ciphernone,
	&ciphertwiddle,
};

Auth *allauth[] = {
	&authpassword,
	&authrsa,
	&authtis,
};

char *cipherlist = "rc4 3des";
char *authlist = "rsa password tis";

/*
 * Return the cipher in list[0..nlist) whose name is name.
 * An unknown name is reported through error().
 */
Cipher*
findcipher(char *name, Cipher **list, int nlist)
{
	int i;

	for(i=0; i<nlist; i++)
		if(strcmp(name, list[i]->name) == 0)
			return list[i];
	error("unknown cipher %s", name);
	return nil;
}

/*
 * Return the auth method in list[0..nlist) whose name is name.
 * An unknown name is reported through error().
 */
Auth*
findauth(char *name, Auth **list, int nlist)
{
	int i;

	for(i=0; i<nlist; i++)
		if(strcmp(name, list[i]->name) == 0)
			return list[i];
	error("unknown auth %s", name);
	return nil;
}

void
usage(void)
{
	fprint(2, "usage: sshnet [-A authlist] [-c cipherlist] [-m mtpt] [-s service] [user@]hostname\n");
	exits("usage");
}

/*
 * Report whether fd refers to a console: its path ends in /dev/cons.
 */
int
isatty(int fd)
{
	char buf[64];

	buf[0] = '\0';
	fd2path(fd, buf, sizeof buf);
	if(strlen(buf)>=9 && strcmp(buf+strlen(buf)-9, "/dev/cons")==0)
		return 1;
	return 0;
}

/*
 * File types; each one indexes tab[] below, so the order must match.
 */
enum
{
	Qroot,
	Qcs,
	Qtcp,
	Qclone,
	Qn,
	Qctl,
	Qdata,
	Qlocal,
	Qremote,
	Qstatus,
};

/* a qid path packs the file type in the low 8 bits and the client number above it */
#define PATH(type, n)	((type)|((n)<<8))
#define TYPE(path)	((int)(path) & 0xFF)
#define NUM(path)	((uint)(path)>>8)

Channel *sshmsgchan;		/* chan(Msg*) */
Channel *fsreqchan;		/* chan(Req*) */
Channel *fsreqwaitchan;		/* chan(nil) */
Channel *fsclunkchan;		/* chan(Fid*) */
Channel *fsclunkwaitchan;	/* chan(nil) */
ulong time0;			/* start time; used as atime/mtime of every file */

/* connection states; statestr[] below must stay in the same order */
enum
{
	Closed,
	Dialing,
	Established,
	Teardown,
};

char *statestr[] = {
	"Closed",
	"Dialing",
	"Established",
	"Teardown",
};

typedef struct Client Client;
struct Client
{
	int ref;	/* open fids on this client's files */
	int state;	/* one of Closed, Dialing, ... */
	int num;	/* our channel number == index in client[] */
	int servernum;	/* the server's channel number */
	char *connect;	/* address given to the last connect message */

	Req *rq;	/* pending requests, linked through Req.aux */
	Req **erq;	/* tail pointer of rq */

	Msg *mq;	/* received data not yet read, linked through Msg.link */
	Msg **emq;	/* tail pointer of mq */
};

int nclient;
Client **client;

/*
 * Return the number of a free client: a closed, unreferenced one
 * is reused, otherwise a new zeroed one is appended.
 * client[] grows 16 entries at a time.
 */
int
newclient(void)
{
	int i;
	Client *c;

	for(i=0; i<nclient; i++)
		if(client[i]->ref==0 && client[i]->state == Closed)
			return i;

	if(nclient%16 == 0)
		client = erealloc9p(client, (nclient+16)*sizeof(client[0]));

	c = emalloc9p(sizeof(Client));
	memset(c, 0, sizeof(*c));
	c->num = nclient;
	client[nclient++] = c;
	return c->num;
}

/* Append r to c's pending-request queue (linked through r->aux). */
void
queuereq(Client *c, Req *r)
{
	if(c->rq==nil)
		c->erq = &c->rq;
	*c->erq = r;
	r->aux = nil;
	c->erq = (Req**)&r->aux;
}

/* Append m to c's received-data queue (linked through m->link). */
void
queuemsg(Client *c, Msg *m)
{
	if(c->mq==nil)
		c->emq = &c->mq;
	*c->emq = m;
	m->link = nil;
	c->emq = (Msg**)&m->link;
}

/*
 * Satisfy queued reads from queued data messages.  Each read gets
 * at most ifcall.count bytes from the head message; a message that
 * is only partly consumed stays at the head with rp advanced.
 */
void
matchmsgs(Client *c)
{
	Req *r;
	Msg *m;
	int n, rm;

	while(c->rq && c->mq){
		r = c->rq;
		c->rq = r->aux;

		rm = 0;
		m = c->mq;
		n = r->ifcall.count;
		if(n >= m->ep - m->rp){
			n = m->ep - m->rp;
			c->mq = m->link;
			rm = 1;
		}
		memmove(r->ofcall.data, m->rp, n);
		if(rm)
			free(m);
		else
			m->rp += n;
		r->ofcall.count = n;
		respond(r, nil);
	}
}

/*
 * Unlink r from c's request queue, fixing the tail pointer when
 * the last entry is removed.  Returns r if found, nil otherwise.
 */
Req*
findreq(Client *c, Req *r)
{
	Req **l;

	for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
		if(*l == r){
			*l = r->aux;
			if(*l == nil)
				c->erq = l;
			return r;
		}
	}
	return nil;
}

/*
 * Answer the pending connect request, if any, according to
 * whether the dial reached Established.
 */
void
dialedclient(Client *c)
{
	Req *r;

	if((r = c->rq) != nil){
		if(r->aux != nil)
			sysfatal("more than one outstanding dial request (BUG)");
		if(c->state == Established)
			respond(r, nil);
		else
			respond(r, "connect failed");
	}
	c->rq = nil;
}

/* Start an orderly close: send INPUT_EOF for the server's channel. */
void
teardownclient(Client *c)
{
	Msg *m;

	c->state = Teardown;
	m = allocmsg(conn, SSH_MSG_CHANNEL_INPUT_EOF, 4);
	putlong(m, c->servernum);
	sendmsg(m);
}

/*
 * The channel is gone: mark c closed, discard unread data and
 * fail every pending request.
 */
void
hangupclient(Client *c)
{
	Req *r, *next;
	Msg *m, *mnext;

	c->state = Closed;
	for(m=c->mq; m; m=mnext){
		mnext = m->link;
		free(m);
	}
	c->mq = nil;
	for(r=c->rq; r; r=next){
		next = r->aux;
		respond(r, "hangup on network connection");
	}
	c->rq = nil;
}

/*
 * Drop one reference to c.  On the last one, discard unread data
 * and, if the channel is still open, start tearing it down.
 */
void
closeclient(Client *c)
{
	Msg *m, *next;

	if(--c->ref)
		return;

	if(c->rq != nil)
		sysfatal("ref count reached zero with requests pending (BUG)");

	for(m=c->mq; m; m=next){
		next = m->link;
		free(m);
	}
	c->mq = nil;

	if(c->state != Closed)
		teardownclient(c);
}

/* Forward every message read from the ssh connection to sshmsgchan. */
void
sshreadproc(void *a)
{
	Conn *c;
	Msg *m;

	c = a;
	for(;;){
		m = recvmsg(c, -1);
		if(m == nil)
			sysfatal("eof on ssh connection");
		sendp(sshmsgchan, m);
	}
}

typedef struct Tab Tab;
struct Tab
{
	char *name;	/* nil for the numbered client directories */
	ulong mode;
};

/* indexed by the Q* file types above */
Tab tab[] =
{
	{ "/",		DMDIR|0555 },
	{ "cs",		0666 },
	{ "tcp",	DMDIR|0555 },
	{ "clone",	0666 },
	{ nil,		DMDIR|0555 },
	{ "ctl",	0666 },
	{ "data",	0666 },
	{ "local",	0444 },
	{ "remote",	0444 },
	{ "status",	0444 },
};

/* Fill d with the stat information for the file at qid path. */
static void
fillstat(Dir *d, uvlong path)
{
	Tab *t;

	memset(d, 0, sizeof(*d));
	d->uid = estrdup9p("ssh");
	d->gid = estrdup9p("ssh");
	d->qid.path = path;
	d->atime = d->mtime = time0;
	t = &tab[TYPE(path)];
	if(t->name)
		d->name = estrdup9p(t->name);
	else{
		d->name = smprint("%ud", NUM(path));
		if(d->name == nil)
			sysfatal("out of memory");
	}
	d->qid.type = t->mode>>24;
	d->mode = t->mode;
}

static void
fsattach(Req *r)
{
	if(r->ifcall.aname && r->ifcall.aname[0]){
		respond(r, "invalid attach specifier");
		return;
	}
	r->fid->qid.path = PATH(Qroot, 0);
	r->fid->qid.type = QTDIR;
	r->fid->qid.vers = 0;
	r->ofcall.qid = r->fid->qid;
	respond(r, nil);
}

static void
fsstat(Req *r)
{
	fillstat(&r->d, r->fid->qid.path);
	respond(r, nil);
}

/* dirread9p generator for /: cs and tcp */
static int
rootgen(int i, Dir *d, void*)
{
	i += Qroot+1;
	if(i <= Qtcp){
		fillstat(d, i);
		return 0;
	}
	return -1;
}

/* dirread9p generator for tcp: clone followed by one directory per client */
static int
tcpgen(int i, Dir *d, void*)
{
	i += Qtcp+1;
	if(i < Qn){
		fillstat(d, i);
		return 0;
	}
	i -= Qn;
	if(i < nclient){
		fillstat(d, PATH(Qn, i));
		return 0;
	}
	return -1;
}

/* dirread9p generator for a client directory; aux is the Client */
static int
clientgen(int i, Dir *d, void *aux)
{
	Client *c;

	c = aux;
	i += Qn+1;
	if(i <= Qstatus){
		fillstat(d, PATH(i, c->num));
		return 0;
	}
	return -1;
}

/*
 * Fill every field of *qid for the file at path.  The walk code
 * sets all fields rather than relying on how the caller
 * initialized the Qid.
 */
static void
setqid(Qid *qid, ulong path)
{
	qid->path = path;
	qid->vers = 0;
	qid->type = tab[TYPE(path)].mode>>24;
}

/*
 * Walk fid one step to name, returning the new qid in *qid.
 * Returns nil on success or an error string.
 */
static char*
fswalk1(Fid *fid, char *name, Qid *qid)
{
	int i, n;
	char buf[32];
	ulong path;

	path = fid->qid.path;
	if(!(fid->qid.type&QTDIR))
		return "walk in non-directory";

	if(strcmp(name, "..") == 0){
		switch(TYPE(path)){
		case Qn:
			setqid(qid, PATH(Qtcp, NUM(path)));
			return nil;
		case Qtcp:
			setqid(qid, PATH(Qroot, 0));
			return nil;
		case Qroot:
			/* .. of the root is the root itself */
			setqid(qid, PATH(Qroot, 0));
			return nil;
		default:
			return "bug in fswalk1";
		}
	}

	/* the children of a directory follow it in tab[] */
	i = TYPE(path)+1;
	for(; i<nelem(tab); i++){
		if(i==Qn){
			/*
			 * Client directories are named by canonical decimal
			 * numbers.  Reject negatives ("-1" round-trips through
			 * atoi/snprint) so NUM() never indexes client[] out of range.
			 */
			n = atoi(name);
			snprint(buf, sizeof buf, "%d", n);
			if(n >= 0 && n < nclient && strcmp(buf, name) == 0){
				setqid(qid, PATH(i, n));
				return nil;
			}
			break;
		}
		if(strcmp(name, tab[i].name) == 0){
			setqid(qid, PATH(i, NUM(path)));
			return nil;
		}
		/* a directory ends the parent's run of children */
		if(tab[i].mode&DMDIR)
			break;
	}
	return "directory entry not found";
}

typedef struct Cs Cs;
struct Cs
{
	char *resp;	/* last translation written */
	int isnew;	/* resp not yet read from offset 0 */
};

/*
 * Translate a port number or tcp service name to a port number.
 * Numeric strings (any base strtol accepts) are used directly;
 * names are looked up in /lib/ndb/common.  Returns -1 on failure.
 */
static int
ndbfindport(char *p)
{
	char *s, *port;
	int n;
	static Ndb *db;

	if(*p == '\0')
		return -1;

	n = strtol(p, &s, 0);
	if(*s == '\0')
		return n;

	if(db == nil){
		db = ndbopen("/lib/ndb/common");
		if(db == nil)
			return -1;
	}

	port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
	if(port == nil)
		return -1;
	n = atoi(port);
	free(port);

	return n;
}

/*
 * Read the answer to the last cs write.  A read at offset 0 returns
 * the answer only once per write; after that it returns 0 bytes.
 */
static void
csread(Req *r)
{
	Cs *cs;

	cs = r->fid->aux;
	if(cs->resp==nil){
		respond(r, "cs read without write");
		return;
	}
	if(r->ifcall.offset==0){
		if(!cs->isnew){
			r->ofcall.count = 0;
			respond(r, nil);
			return;
		}
		cs->isnew = 0;
	}
	readstr(r, cs->resp);
	respond(r, nil);
}

/*
 * Translate "tcp!host!port" or "net!host!port" into
 * "mtpt/tcp/clone host!portnum" for a later csread.
 */
static void
cswrite(Req *r)
{
	int port, nf;
	char err[ERRMAX], *f[4], *s, *ns;
	Cs *cs;

	cs = r->fid->aux;
	s = emalloc(r->ifcall.count+1);
	memmove(s, r->ifcall.data, r->ifcall.count);
	s[r->ifcall.count] = '\0';

	nf = getfields(s, f, nelem(f), 0, "!");
	if(nf != 3){
		free(s);
		respond(r, "can't translate");
		return;
	}
	if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
		free(s);
		respond(r, "unknown protocol");
		return;
	}
	port = ndbfindport(f[2]);
	if(port <= 0){
		free(s);
		respond(r, "no translation found");
		return;
	}

	ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
	if(ns == nil){
		free(s);
		rerrstr(err, sizeof err);
		respond(r, err);
		return;
	}
	free(s);
	free(cs->resp);
	cs->resp = ns;
	cs->isnew = 1;
	r->ofcall.count = r->ifcall.count;
	respond(r, nil);
}

/* Reading ctl yields the client number. */
static void
ctlread(Req *r, Client *c)
{
	char buf[32];

	sprint(buf, "%d", c->num);
	readstr(r, buf);
	respond(r, nil);
}

/*
 * Handle a ctl message:
 *	hangup		tear down an established connection; the
 *			request stays queued until the channel closes
 *	connect host!port
 *			ask the server to open a channel to host:port;
 *			the request is answered by dialedclient
 */
static void
ctlwrite(Req *r, Client *c)
{
	char *f[3], *s;
	int nf, port;
	Msg *m;

	s = emalloc(r->ifcall.count+1);
	memmove(s, r->ifcall.data, r->ifcall.count);
	s[r->ifcall.count] = '\0';

	nf = tokenize(s, f, 3);
	if(nf == 0){
		free(s);
		respond(r, nil);
		return;
	}

	if(strcmp(f[0], "hangup") == 0){
		if(c->state != Established)
			goto Badarg;
		if(nf != 1)
			goto Badarg;
		queuereq(c, r);
		teardownclient(c);
	}else if(strcmp(f[0], "connect") == 0){
		if(c->state != Closed)
			goto Badarg;
		if(nf != 2)
			goto Badarg;
		/* a reused client may still hold the previous address */
		free(c->connect);
		c->connect = estrdup9p(f[1]);
		nf = getfields(f[1], f, nelem(f), 0, "!");
		/* don't send ndbfindport's -1 failure value to the server as a port */
		if(nf != 2 || (port = ndbfindport(f[1])) <= 0){
			free(c->connect);
			c->connect = nil;
			goto Badarg;
		}
		c->state = Dialing;
		m = allocmsg(conn, SSH_MSG_PORT_OPEN, 4+4+strlen(f[0])+4+4+strlen("localhost"));
		putlong(m, c->num);
		putstring(m, f[0]);
		putlong(m, port);
		putstring(m, "localhost");
		queuereq(c, r);
		sendmsg(m);
	}else{
	Badarg:
		respond(r, "bad or inappropriate tcp control message");
	}
	free(s);
}

/* Queue a read on an established connection; matchmsgs answers it. */
static void
dataread(Req *r, Client *c)
{
	if(c->state != Established){
		respond(r, "not connected");
		return;
	}

	queuereq(c, r);
	matchmsgs(c);
}

/* Send the written bytes as one CHANNEL_DATA message. */
static void
datawrite(Req *r, Client *c)
{
	Msg *m;

	if(c->state != Established){
		respond(r, "not connected");
		return;
	}

	if(r->ifcall.count){
		m = allocmsg(conn, SSH_MSG_CHANNEL_DATA, 4+4+r->ifcall.count);
		putlong(m, c->servernum);
		putlong(m, r->ifcall.count);
		putbytes(m, r->ifcall.data, r->ifcall.count);
		sendmsg(m);
	}
	r->ofcall.count = r->ifcall.count;
	respond(r, nil);
}

static void
localread(Req *r)
{
	char buf[128];

	snprint(buf, sizeof buf, "%s!%d\n", remoteip, 0);
	readstr(r, buf);
	respond(r, nil);
}

/* Reading remote yields the address of the last connect, or ::!0. */
static void
remoteread(Req *r, Client *c)
{
	char *s;
	char buf[128];

	s = c->connect;
	if(s == nil)
		s = "::!0";
	snprint(buf, sizeof buf, "%s\n", s);
	readstr(r, buf);
	respond(r, nil);
}

/* Reading status yields the name of the connection state. */
static void
statusread(Req *r, Client *c)
{
	readstr(r, statestr[c->state]);
	respond(r, nil);
}

static void
fsread(Req *r)
{
	char e[ERRMAX];
	ulong path;

	path = r->fid->qid.path;
	switch(TYPE(path)){
	default:
		snprint(e, sizeof e, "bug in fsread path=%lux", path);
		respond(r, e);
		break;

	case Qroot:
		dirread9p(r, rootgen, nil);
		respond(r, nil);
		break;

	case Qcs:
		csread(r);
		break;

	case Qtcp:
		dirread9p(r, tcpgen, nil);
		respond(r, nil);
		break;

	case Qn:
		dirread9p(r, clientgen, client[NUM(path)]);
		respond(r, nil);
		break;

	case Qctl:
		ctlread(r, client[NUM(path)]);
		break;

	case Qdata:
		dataread(r, client[NUM(path)]);
		break;

	case Qlocal:
		localread(r);
		break;

	case Qremote:
		remoteread(r, client[NUM(path)]);
		break;

	case Qstatus:
		statusread(r, client[NUM(path)]);
		break;
	}
}

static void
fswrite(Req *r)
{
	ulong path;
	char e[ERRMAX];

	path = r->fid->qid.path;
	switch(TYPE(path)){
	default:
		snprint(e, sizeof e, "bug in fswrite path=%lux", path);
		respond(r, e);
		break;

	case Qcs:
		cswrite(r);
		break;

	case Qctl:
		ctlwrite(r, client[NUM(path)]);
		break;

	case Qdata:
		datawrite(r, client[NUM(path)]);
		break;
	}
}

/*
 * Check permissions, allocate cs state, and turn an open of
 * clone into an open of a fresh client's ctl file.  Opening any
 * file inside a client directory takes a reference on the client.
 */
static void
fsopen(Req *r)
{
	static int need[4] = { 4, 2, 6, 1 };
	ulong path;
	int n;
	Tab *t;
	Cs *cs;

	/*
	 * lib9p already handles the blatantly obvious.
	 * we just have to enforce the permissions we have set.
	 */
	path = r->fid->qid.path;
	t = &tab[TYPE(path)];
	n = need[r->ifcall.mode&3];
	if((n&t->mode) != n){
		respond(r, "permission denied");
		return;
	}

	switch(TYPE(path)){
	case Qcs:
		cs = emalloc(sizeof(Cs));
		r->fid->aux = cs;
		respond(r, nil);
		break;
	case Qclone:
		n = newclient();
		path = PATH(Qctl, n);
		r->fid->qid.path = path;
		r->ofcall.qid.path = path;
		if(chatty9p)
			fprint(2, "open clone => path=%lux\n", path);
		t = &tab[Qctl];
		/* fall through */
	default:
		if(t-tab >= Qn)
			client[NUM(path)]->ref++;
		respond(r, nil);
		break;
	}
}

/* Cancel the flushed request if it is still queued on some client. */
static void
fsflush(Req *r)
{
	int i;

	for(i=0; i<nclient; i++)
		if(findreq(client[i], r->oldreq))
			respond(r->oldreq, "interrupted");
	respond(r, nil);
}

/*
 * Map a channel number received from the server to its client,
 * or nil if it is out of range.  The number comes off the wire,
 * so both bounds are checked: a large unsigned value stored in
 * an int becomes negative.
 */
static Client*
chanclient(int chan)
{
	if(chan < 0 || chan >= nclient)
		return nil;
	return client[chan];
}

/* Dispatch one message from the ssh connection.  Consumes m. */
static void
handlemsg(Msg *m)
{
	int chan, n;
	Client *c;

	switch(m->type){
	case SSH_MSG_DISCONNECT:
	case SSH_CMSG_EXIT_CONFIRMATION:
		sysfatal("disconnect");

	case SSH_CMSG_STDIN_DATA:
	case SSH_CMSG_EOF:
	case SSH_CMSG_WINDOW_SIZE:
		/* don't care */
		free(m);
		break;

	case SSH_MSG_CHANNEL_DATA:
		chan = getlong(m);
		n = getlong(m);
		if(m->rp+n != m->ep)
			sysfatal("got bad channel data");
		c = chanclient(chan);
		if(c != nil && c->state == Established){
			queuemsg(c, m);
			matchmsgs(c);
		}else
			free(m);
		break;

	case SSH_MSG_CHANNEL_INPUT_EOF:
		chan = getlong(m);
		free(m);
		if((c = chanclient(chan)) != nil){
			chan = c->servernum;
			hangupclient(c);
			m = allocmsg(conn, SSH_MSG_CHANNEL_OUTPUT_CLOSED, 4);
			putlong(m, chan);
			sendmsg(m);
		}
		break;

	case SSH_MSG_CHANNEL_OUTPUT_CLOSED:
		chan = getlong(m);
		if((c = chanclient(chan)) != nil)
			hangupclient(c);
		free(m);
		break;

	case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
		chan = getlong(m);
		c = chanclient(chan);
		if(c == nil || c->state != Dialing){
			if(c)
				fprint(2, "cstate %d\n", c->state);
			sysfatal("got unexpected open confirmation for %d", chan);
		}
		c->servernum = getlong(m);
		c->state = Established;
		dialedclient(c);
		free(m);
		break;

	case SSH_MSG_CHANNEL_OPEN_FAILURE:
		chan = getlong(m);
		c = chanclient(chan);
		if(c == nil || c->state != Dialing)
			sysfatal("got unexpected open failure");
		if(m->rp+4 <= m->ep)
			c->servernum = getlong(m);
		c->state = Closed;
		dialedclient(c);
		free(m);
		break;

	default:
		/* every other path consumes m; don't leak unexpected types */
		free(m);
		break;
	}
}

/*
 * Serialize fid clunks, 9P requests and incoming ssh messages
 * through one alt loop.  Clunks and requests are acknowledged on
 * their wait channels once handled.
 */
void
fsnetproc(void*)
{
	ulong path;
	Alt a[4];
	Cs *cs;
	Fid *fid;
	Req *r;
	Msg *m;

	threadsetname("fsthread");

	a[0].op = CHANRCV;
	a[0].c = fsclunkchan;
	a[0].v = &fid;
	a[1].op = CHANRCV;
	a[1].c = fsreqchan;
	a[1].v = &r;
	a[2].op = CHANRCV;
	a[2].c = sshmsgchan;
	a[2].v = &m;
	a[3].op = CHANEND;

	for(;;){
		switch(alt(a)){
		case 0:
			path = fid->qid.path;
			switch(TYPE(path)){
			case Qcs:
				cs = fid->aux;
				if(cs){
					free(cs->resp);
					free(cs);
				}
				break;
			}
			/* release the reference fsopen took */
			if(fid->omode != -1 && TYPE(path) >= Qn)
				closeclient(client[NUM(path)]);
			sendp(fsclunkwaitchan, nil);
			break;
		case 1:
			switch(r->ifcall.type){
			case Tattach:
				fsattach(r);
				break;
			case Topen:
				fsopen(r);
				break;
			case Tread:
				fsread(r);
				break;
			case Twrite:
				fswrite(r);
				break;
			case Tstat:
				fsstat(r);
				break;
			case Tflush:
				fsflush(r);
				break;
			default:
				respond(r, "bug in fsthread");
				break;
			}
			sendp(fsreqwaitchan, 0);
			break;
		case 2:
			handlemsg(m);
			break;
		}
	}
}

/* Hand r to fsnetproc and wait until it has been dispatched. */
static void
fssend(Req *r)
{
	sendp(fsreqchan, r);
	recvp(fsreqwaitchan);	/* avoids need to deal with spurious flushes */
}

/* Hand a dying fid to fsnetproc and wait until it has been released. */
static void
fsdestroyfid(Fid *fid)
{
	sendp(fsclunkchan, fid);
	recvp(fsclunkwaitchan);
}

void
takedown(Srv*)
{
	threadexitsall("done");
}

Srv fs =
{
.attach=	fssend,
.destroyfid=	fsdestroyfid,
.walk1=		fswalk1,
.open=		fssend,
.read=		fssend,
.write=		fssend,
.stat=		fssend,
.flush=		fssend,
.end=		takedown,
};

/*
 * Dial the ssh server, authenticate, start a shell session, then
 * start the reader and file-server procs and post/mount the
 * file system at mtpt.
 */
void
threadmain(int argc, char **argv)
{
	int i, fd;
	char *host, *user, *p, *service;
	char *f[16];
	Msg *m;
	static Conn c;

	fmtinstall('B', mpfmt);
	fmtinstall('H', encodefmt);

	mtpt = "/net";
	service = nil;
	user = nil;
	ARGBEGIN{
	case 'B':	/* undocumented, debugging */
		doabort = 1;
		break;
	case 'D':	/* undocumented, debugging */
		debuglevel = strtol(EARGF(usage()), nil, 0);
		break;
	case '9':	/* undocumented, debugging */
		chatty9p++;
		break;

	case 'A':
		authlist = EARGF(usage());
		break;
	case 'c':
		cipherlist = EARGF(usage());
		break;
	case 'm':
		mtpt = EARGF(usage());
		break;
	case 's':
		service = EARGF(usage());
		break;
	default:
		usage();
	}ARGEND

	if(argc != 1)
		usage();

	host = argv[0];

	if((p = strchr(host, '@')) != nil){
		*p++ = '\0';
		user = host;
		host = p;
	}
	if(user == nil)
		user = getenv("user");
	if(user == nil)
		sysfatal("cannot find user name");

	privatefactotum();

	if((fd = dial(netmkaddr(host, "tcp", "ssh"), nil, nil, nil)) < 0)
		sysfatal("dialing %s: %r", host);

	c.interactive = isatty(0);
	c.fd[0] = c.fd[1] = fd;
	c.user = user;
	c.host = host;
	setaliases(&c, host);

	c.nokcipher = getfields(cipherlist, f, nelem(f), 1, ", ");
	c.okcipher = emalloc(sizeof(Cipher*)*c.nokcipher);
	for(i=0; i<c.nokcipher; i++)
		c.okcipher[i] = findcipher(f[i], allcipher, nelem(allcipher));

	c.nokauth = getfields(authlist, f, nelem(f), 1, ", ");
	c.okauth = emalloc(sizeof(Auth*)*c.nokauth);
	for(i=0; i<c.nokauth; i++)
		c.okauth[i] = findauth(f[i], allauth, nelem(allauth));

	sshclienthandshake(&c);

	requestpty(&c);		/* turns on TCP_NODELAY on other side */
	m = allocmsg(&c, SSH_CMSG_EXEC_SHELL, 0);
	sendmsg(m);

	time0 = time(0);
	sshmsgchan = chancreate(sizeof(Msg*), 16);
	fsreqchan = chancreate(sizeof(Req*), 0);
	fsreqwaitchan = chancreate(sizeof(void*), 0);
	fsclunkchan = chancreate(sizeof(Fid*), 0);
	fsclunkwaitchan = chancreate(sizeof(void*), 0);

	conn = &c;
	procrfork(sshreadproc, &c, 8192, RFNAMEG|RFNOTEG);
	procrfork(fsnetproc, nil, 8192, RFNAMEG|RFNOTEG);

	threadpostmountsrv(&fs, service, mtpt, MREPL);
	exits(0);
}