shithub: riscv

ref: 8319b750ea7780bc8452fdda078ecd17f0b3e81d
dir: /sys/src/cmd/ssh.c/

View raw version
#include <u.h>
#include <libc.h>
#include <mp.h>
#include <libsec.h>
#include <auth.h>
#include <authsrv.h>

/*
 * SSH message numbers (RFC 4250).  Values are written out explicitly
 * so they can be checked against the RFC at a glance.
 */
enum {
	/* transport layer, generic */
	MSG_DISCONNECT			= 1,
	MSG_IGNORE			= 2,
	MSG_UNIMPLEMENTED		= 3,
	MSG_DEBUG			= 4,
	MSG_SERVICE_REQUEST		= 5,
	MSG_SERVICE_ACCEPT		= 6,

	/* algorithm negotiation */
	MSG_KEXINIT			= 20,
	MSG_NEWKEYS			= 21,

	/* key exchange method specific (ECDH) */
	MSG_ECDH_INIT			= 30,
	MSG_ECDH_REPLY			= 31,

	/* user authentication, generic */
	MSG_USERAUTH_REQUEST		= 50,
	MSG_USERAUTH_FAILURE		= 51,
	MSG_USERAUTH_SUCCESS		= 52,
	MSG_USERAUTH_BANNER		= 53,

	/*
	 * user authentication, method specific: 60 is PK_OK for
	 * "publickey" and INFO_REQUEST for "keyboard-interactive".
	 */
	MSG_USERAUTH_PK_OK		= 60,
	MSG_USERAUTH_INFO_REQUEST	= 60,
	MSG_USERAUTH_INFO_RESPONSE	= 61,

	/* connection protocol, global */
	MSG_GLOBAL_REQUEST		= 80,
	MSG_REQUEST_SUCCESS		= 81,
	MSG_REQUEST_FAILURE		= 82,

	/* connection protocol, channels */
	MSG_CHANNEL_OPEN		= 90,
	MSG_CHANNEL_OPEN_CONFIRMATION	= 91,
	MSG_CHANNEL_OPEN_FAILURE	= 92,
	MSG_CHANNEL_WINDOW_ADJUST	= 93,
	MSG_CHANNEL_DATA		= 94,
	MSG_CHANNEL_EXTENDED_DATA	= 95,
	MSG_CHANNEL_EOF			= 96,
	MSG_CHANNEL_CLOSE		= 97,
	MSG_CHANNEL_REQUEST		= 98,
	MSG_CHANNEL_SUCCESS		= 99,
	MSG_CHANNEL_FAILURE		= 100,
};


/*
 * Buffer sizing.  Each Oneway buffer holds one MaxPacket payload plus
 * Overhead bytes, enough for a MSG_CHANNEL_DATA header.  The channel
 * window advertised to the server is WinPackets packets.
 */
enum {
	Overhead	= 256,
	MaxPacket	= 32768,	/* 1<<15 */
	WinPackets	= 8,		/* 8 * 32K = 256K window */
};

/* upper bound on keyboard-interactive attempts (kbintauth); set with -T */
int MaxPwTries = 3;

/*
 * State for one direction of the connection: recv holds what comes
 * from the server, send what goes to it.
 */
typedef struct
{
	u32int		seq;	/* packet sequence number; nonce for setupcs */
	u32int		kex;	/* rekey once seq reaches this (checked in main) */
	u32int		chan;	/* channel number: ours in recv, the server's in send */

	int		win;	/* remaining channel window, in bytes */
	int		pkt;	/* maximum packet size for the channel */
	int		eof;	/* set when this direction is finished */

	Chachastate	cs1;	/* keyed for the 4-byte packet length */
	Chachastate	cs2;	/* keyed for the payload and the poly1305 one-time key */

	uchar		*r;	/* start of the current payload in b */
	uchar		*w;	/* end of the current payload in b */
	uchar		b[Overhead + MaxPacket];

	char		*v;	/* SSH version string */
	int		pid;	/* process handling this direction */
	Rendez;		/* Plan 9 anonymous member: rsleep/rwakeup on &send wait for window space */
} Oneway;

/*
 * Session identifier: the exchange hash of the first key exchange
 * (set in kex).  A nonzero nsid also means packets are encrypted.
 */
int nsid;
uchar sid[256];
/* base64 SHA-256 host key fingerprint, and the file it is checked against */
char thumb[2*SHA2_256dlen+1], *thumbfile;

/*
 * fd: connection to the server; intr: interrupt note pending;
 * raw: pty mode; port: -W port; mux: -X mode; debug: -d count.
 */
int fd, intr, raw, port, mux, debug;
/* status: exit status/signal reported by the server, used for exits */
char *user, *service, *status, *host, *remote, *cmd;

/* per-direction state: recv is server->client, send is client->server */
Oneway recv, send;
void dispatch(void);

/*
 * Mark both directions finished and, if the input-reading process
 * has been started, post it a "shutdown" note to interrupt it.
 */
void
shutdown(void)
{
	send.eof = 1;
	recv.eof = 1;
	if(send.pid <= 0)
		return;
	postnote(PNPROC, send.pid, "shutdown");
}

/*
 * Note handler: an "interrupt" note only sets intr and resumes the
 * process; every other note gets the default action.
 */
void
catch(void*, char *msg)
{
	if(strcmp(msg, "interrupt") != 0)
		noted(NDFLT);
	intr = 1;
	noted(NCONT);
}

int
wasintr(void)
{
	char err[ERRMAX];
	int r;

	memset(err, 0, sizeof(err));
	errstr(err, sizeof(err));
	r = strcmp(err, "interrupted") == 0;
	errstr(err, sizeof(err));
	return r;
}

/*
 * Big-endian 32-bit store and load.  Both expansions are fully
 * parenthesized so they compose safely inside larger expressions
 * (e.g. GET4(p) & 0xff).  The arguments are evaluated more than
 * once, so pass expressions without side effects.
 */
#define PUT4(p, u) ((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p)	((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)

/*
 * Serialize the arguments into p[0..n) in SSH wire format, one
 * item per verb in fmt:
 *	_	skip one output byte (left unwritten)
 *	.	store the current output position in a void** argument
 *	b	one byte (int argument)
 *	m	mpint (mpint* argument): 4-byte length, big-endian bytes
 *	s	string (pointer, int length): 4-byte length, then the bytes
 *	[	raw bytes (pointer, int length), no length prefix
 *	u	32-bit big-endian integer (int argument)
 * Returns the number of bytes written, or -1 if the output does not
 * fit in n bytes or fmt contains an unknown verb.
 */
int
vpack(uchar *p, int n, char *fmt, va_list a)
{
	uchar *p0 = p, *e = p+n;
	u32int u;
	mpint *m;
	void *s;
	int c;

	/* bounds are tested as e-p so p never moves past e */
	for(;;){
		switch(c = *fmt++){
		case '\0':
			return p - p0;
		case '_':
			if(p >= e) goto err;
			p++;
			break;
		case '.':
			*va_arg(a, void**) = p;
			break;
		case 'b':
			if(p >= e) goto err;
			*p++ = va_arg(a, int);
			break;
		case 'm':
			m = va_arg(a, mpint*);
			/* +8 (not +7) leaves room for a leading zero when the top bit is set */
			u = (mpsignif(m)+8)/8;
			if(e-p < 4) goto err;
			PUT4(p, u), p += 4;
			if(u > e-p) goto err;
			mptober(m, p, u), p += u;
			break;
		case '[':
		case 's':
			s = va_arg(a, void*);
			u = va_arg(a, int);
			if(c == 's'){
				if(e-p < 4) goto err;
				PUT4(p, u), p += 4;
			}
			/* a negative length becomes a huge u32 and is rejected here */
			if(u > e-p) goto err;
			memmove(p, s, u);
			p += u;
			break;
		case 'u':
			u = va_arg(a, int);
			if(e-p < 4) goto err;
			PUT4(p, u), p += 4;
			break;
		default:
			/* unknown verb: the remaining va_list can't be interpreted */
			goto err;
		}
	}
err:
	return -1;
}

/*
 * Parse p[0..n) as SSH wire format according to fmt, storing into
 * the pointer arguments, one item per verb:
 *	_	skip one input byte
 *	.	store the current input position in a void** argument
 *	b	one byte into an int*
 *	m	mpint (4-byte length + bytes) into an existing mpint*
 *	s	string: a pointer into p (void**) and its length (int*);
 *		the bytes are neither copied nor NUL-terminated
 *	[	copy a caller-given int length of raw bytes into a buffer
 *	u	32-bit big-endian integer into an int*
 * Returns the number of bytes consumed, or -1 if the input is too
 * short or fmt contains an unknown verb.
 */
int
vunpack(uchar *p, int n, char *fmt, va_list a)
{
	uchar *p0 = p, *e = p+n;
	u32int u;
	mpint *m;
	void *s;

	/* bounds are tested as e-p so p never moves past e */
	for(;;){
		switch(*fmt++){
		case '\0':
			return p - p0;
		case '_':
			if(p >= e) goto err;
			p++;
			break;
		case '.':
			*va_arg(a, void**) = p;
			break;
		case 'b':
			if(p >= e) goto err;
			*va_arg(a, int*) = *p++;
			break;
		case 'm':
			if(e-p < 4) goto err;
			u = GET4(p), p += 4;
			if(u > e-p) goto err;
			m = va_arg(a, mpint*);
			betomp(p, u, m), p += u;
			break;
		case 's':
			if(e-p < 4) goto err;
			u = GET4(p), p += 4;
			if(u > e-p) goto err;
			*va_arg(a, void**) = p;
			*va_arg(a, int*) = u;
			p += u;
			break;
		case '[':
			s = va_arg(a, void*);
			u = va_arg(a, int);
			if(u > e-p) goto err;
			memmove(s, p, u);
			p += u;
			break;
		case 'u':
			if(e-p < 4) goto err;
			u = GET4(p);
			*va_arg(a, int*) = u;
			p += 4;
			break;
		default:
			/* unknown verb: the remaining va_list can't be interpreted */
			goto err;
		}
	}
err:
	return -1;
}

/* pack: variadic front end to vpack; same verbs and return value. */
int
pack(uchar *p, int n, char *fmt, ...)
{
	va_list args;
	int r;

	va_start(args, fmt);
	r = vpack(p, n, fmt, args);
	va_end(args);
	return r;
}

/* unpack: variadic front end to vunpack; same verbs and return value. */
int
unpack(uchar *p, int n, char *fmt, ...)
{
	va_list args;
	int r;

	va_start(args, fmt);
	r = vunpack(p, n, fmt, args);
	va_end(args);
	return r;
}

/*
 * Prepare c's ChaCha states for packet number c->seq.  The nonce is
 * the sequence number as a 64-bit big-endian value, both block
 * counters restart at 0, and the first 32 bytes of cs2's keystream
 * are returned in otk (used as the poly1305 key by sendpkt/recvpkt).
 */
void
setupcs(Oneway *c, uchar otk[32])
{
	uchar nonce[8];

	pack(nonce, sizeof nonce, "uu", 0, c->seq);
	chacha_setiv(&c->cs1, nonce);
	chacha_setiv(&c->cs2, nonce);
	chacha_setblock(&c->cs1, 0);
	chacha_setblock(&c->cs2, 0);

	/* encrypting zeros yields the raw keystream */
	memset(otk, 0, 32);
	chacha_encrypt(otk, 32, &c->cs2);
}

/*
 * Build a packet from fmt and its arguments (see vpack), frame it
 * with its length and random padding, encrypt and MAC it once a
 * session exists (nsid != 0), write it to fd, and increment send.seq.
 * The unpadded payload stays in send.b between send.r and send.w so
 * callers can hash it (kex does).  Exits via sysfatal if the packet
 * does not fit or the write fails.
 */
void
sendpkt(char *fmt, ...)
{
	static uchar buf[sizeof(send.b)];
	int n, pad;
	va_list a;

	va_start(a, fmt);
	n = vpack(send.b, sizeof(send.b), fmt, a);
	va_end(a);
	if(n < 0)
		goto toobig;
	send.r = send.b;
	send.w = send.b+n;

if(debug > 1)
	fprint(2, "sendpkt: (%d) %.*H\n", send.r[0], (int)(send.w-send.r), send.r);

	if(nsid){
		/*
		 * The length field is encrypted separately (cs1), so only
		 * padlen+payload+padding (1+n+pad) is aligned to a multiple
		 * of ChachaBsize; this gives 5 <= pad <= 68.
		 */
		pad = ChachaBsize - ((5+n) % ChachaBsize) + 4;
	} else {
		/* at least 4 bytes, whole packet a multiple of 8 */
		for(pad=4; (5+n+pad) % 8; pad++)
			;
	}
	/* the padding is generated in place right after the payload */
	if(n + pad > sizeof(send.b))
		goto toobig;
	prng(send.w, pad);
	n = pack(buf, sizeof(buf)-16, "ub[[", 1+n+pad, pad, send.b, n, send.w, pad);
	if(n < 0)
		goto toobig;
	if(nsid){
		uchar otk[32];

		setupcs(&send, otk);
		chacha_encrypt(buf, 4, &send.cs1);
		chacha_encrypt(buf+4, n-4, &send.cs2);
		/* 16-byte tag over the encrypted length and payload, appended */
		poly1305(buf, n, otk, sizeof(otk), buf+n, nil);
		n += 16;
	}

	if(write(fd, buf, n) != n)
		sysfatal("write: %r");

	send.seq++;
	return;

toobig:
	sysfatal("sendpkt: message too big");
}

/*
 * Read exactly len bytes from fd into data, restarting reads that a
 * note interrupted.  Returns the number of bytes read; a short count
 * means end of file (error string set to "eof") or a read error.
 */
int
readall(int fd, uchar *data, int len)
{
	int got, r;

	got = 0;
	while(got < len){
		r = read(fd, data+got, len-got);
		if(r > 0){
			got += r;
			continue;
		}
		if(r < 0 && wasintr())
			continue;
		if(r == 0)
			werrstr("eof");
		break;
	}
	return got;
}

/*
 * Read one packet from fd into recv.b and return its message number.
 * Once a session exists (nsid != 0) the 4-byte length is decrypted
 * with recv.cs1, the poly1305 tag over the encrypted length and body
 * is checked against the 16 trailing bytes, and the body is then
 * decrypted with recv.cs2.  On return recv.r..recv.w delimits the
 * payload (padding removed) and recv.seq has been incremented.
 * Bad lengths, short reads, and tag mismatches are fatal.
 */
int
recvpkt(void)
{
	uchar otk[32], tag[16];
	DigestState *ds = nil;
	int n;

	if(readall(fd, recv.b, 4) != 4)
		sysfatal("read1: %r");
	if(nsid){
		setupcs(&recv, otk);
		/* the MAC covers the length field while it is still encrypted */
		ds = poly1305(recv.b, 4, otk, sizeof(otk), nil, nil);
		chacha_encrypt(recv.b, 4, &recv.cs1);
		unpack(recv.b, 4, "u", &n);
		/* the tag follows the packet body */
		n += 16;
	} else {
		unpack(recv.b, 4, "u", &n);
	}
	/* n is now the number of bytes still to read; it must fit in recv.b */
	if(n < 8 || n > sizeof(recv.b)){
badlen:		sysfatal("bad length %d", n);
	}
	/* the rest of the packet overwrites the length field */
	if(readall(fd, recv.b, n) != n)
		sysfatal("read2: %r");
	if(nsid){
		n -= 16;
		if(n < 0) goto badlen;
		poly1305(recv.b, n, otk, sizeof(otk), tag, ds);
		if(tsmemcmp(tag, recv.b+n, 16) != 0)
			sysfatal("bad tag");
		chacha_encrypt(recv.b, n, &recv.cs2);
	}
	/* recv.b[0] is the padding length: drop it, the padding, and the length byte itself */
	n -= recv.b[0]+1;
	if(n < 1) goto badlen;

	recv.r = recv.b + 1;
	recv.w = recv.r + n;
	recv.seq++;

if(debug > 1)
	fprint(2, "recvpkt: (%d) %.*H\n", recv.r[0], (int)(recv.w-recv.r), recv.r);

	return recv.r[0];
}

static char sshrsa[] = "ssh-rsa";

/*
 * Encode an RSA public key as an SSH key blob:
 * string "ssh-rsa", mpint e, mpint n.
 * Returns the blob length, or -1 if it does not fit in len bytes.
 */
int
rsapub2ssh(RSApub *rsa, uchar *data, int len)
{
	int namelen = sizeof(sshrsa) - 1;

	return pack(data, len, "smm", sshrsa, namelen, rsa->ek, rsa->n);
}

/*
 * Decode an "ssh-rsa" key blob into a newly allocated RSApub.
 * Returns nil if the blob is malformed or names another key type.
 */
RSApub*
ssh2rsapub(uchar *data, int len)
{
	RSApub *key;
	char *name;
	int namelen;

	key = rsapuballoc();
	key->n = mpnew(0);
	key->ek = mpnew(0);
	if(unpack(data, len, "smm", &name, &namelen, key->ek, key->n) >= 0
	&& namelen == sizeof(sshrsa)-1 && memcmp(name, sshrsa, namelen) == 0)
		return key;
	rsapubfree(key);
	return nil;
}

static char rsasha256[] = "rsa-sha2-256";

/*
 * Encode signature S as an SSH signature blob:
 * string "rsa-sha2-256", string sig, where sig is S written big-endian
 * into as many bytes as the modulus pub->n has.
 * Returns the blob length, or -1 if it does not fit in len bytes.
 */
int
rsasig2ssh(RSApub *pub, mpint *S, uchar *data, int len)
{
	int siglen, off;
	uchar *sigp;

	siglen = (mpsignif(pub->n)+7)/8;
	/* 4-byte name length, the name itself, 4-byte signature length */
	off = 4 + (sizeof(rsasha256)-1) + 4;
	if(off + siglen > len)
		return -1;
	/* write S at its final position so pack's copy is a no-op */
	sigp = data + off;
	mptober(S, sigp, siglen);
	return pack(data, len, "ss", rsasha256, sizeof(rsasha256)-1, sigp, siglen);
}

/*
 * Decode an "rsa-sha2-256" signature blob into a new mpint.
 * Returns nil if the blob is malformed or of another signature type.
 */
mpint*
ssh2rsasig(uchar *data, int len)
{
	mpint *sig;
	char *name;
	int namelen;

	sig = mpnew(0);
	if(unpack(data, len, "sm", &name, &namelen, sig) >= 0
	&& namelen == sizeof(rsasha256)-1 && memcmp(name, rsasha256, namelen) == 0)
		return sig;
	mpfree(sig);
	return nil;
}

/*
 * Hash data with SHA-256 and return its PKCS#1 v1.5 encoding
 * (DigestInfo padded to the size of pub->n) as an mpint.
 * Returns nil if the digest cannot be encoded.
 */
mpint*
pkcs1digest(uchar *data, int len, RSApub *pub)
{
	uchar digest[SHA2_256dlen], buf[256];
	int n;

	sha2_256(data, len, digest, nil);
	n = asn1encodedigest(sha2_256, digest, buf, sizeof(buf));
	if(n < 0)
		return nil;
	return pkcs1padbuf(buf, n, pub->n, 1);
}

/*
 * Check RSA signature S over data with public key pub: returns 1 if
 * S raised to the public exponent equals the PKCS#1-encoded SHA-256
 * digest of data, else 0.  S is overwritten in place.
 */
int
pkcs1verify(uchar *data, int len, RSApub *pub, mpint *S)
{
	mpint *want;
	int ok;

	if((want = pkcs1digest(data, len, pub)) == nil)
		return 0;
	rsaencrypt(pub, S, S);
	ok = mpcmp(want, S) == 0;
	mpfree(want);
	return ok;
}

/*
 * Feed an SSH string (4-byte big-endian length, then the bytes) into
 * a running SHA-256 state and return the updated state.  Pass
 * ds == nil to start a new hash.
 */
DigestState*
hashstr(void *data, ulong len, DigestState *ds)
{
	uchar lenbuf[4];

	pack(lenbuf, sizeof lenbuf, "u", len);
	ds = sha2_256(lenbuf, sizeof lenbuf, nil, ds);
	return sha2_256((uchar*)data, len, nil, ds);
}

/*
 * Key derivation as in RFC 4253 section 7.2: fill out[0..len) with
 * HASH(K || H || x || session_id), extended as needed with
 * HASH(K || H || everything produced so far).  k (nk bytes) is hashed
 * with a length prefix; h must be SHA2_256dlen bytes long.
 */
void
kdf(uchar *k, int nk, uchar *h, char x, uchar *out, int len)
{
	uchar digest[SHA2_256dlen];
	DigestState *ds;
	int done, chunk;

	ds = hashstr(k, nk, nil);
	ds = sha2_256(h, sizeof(digest), nil, ds);
	ds = sha2_256((uchar*)&x, 1, nil, ds);
	sha2_256(sid, nsid, digest, ds);

	done = 0;
	for(;;){
		chunk = len - done;
		if(chunk > sizeof(digest))
			chunk = sizeof(digest);
		memmove(out+done, digest, chunk);
		done += chunk;
		if(done == len)
			return;
		ds = hashstr(k, nk, nil);
		ds = sha2_256(h, sizeof(digest), nil, ds);
		sha2_256(out, done, digest, ds);
	}
}

/*
 * Run one key exchange (curve25519-sha256, rsa-sha2-256 host key,
 * chacha20-poly1305 in both directions) and install the new keys in
 * send and recv.  If gotkexinit is nonzero, the server's KEXINIT has
 * already been received and is still in recv.r..recv.w (dispatch).
 * The exchange hash H covers both version strings, both KEXINITs, the
 * host key, both ephemeral public keys and the shared secret K.  The
 * server's signature over H is checked with its host key.  On the
 * first exchange (thumb still empty) the host key fingerprint is also
 * checked against thumbfile, and H becomes the session id.  The next
 * rekey is scheduled 100000 packets ahead in each direction.
 */
void
kex(int gotkexinit)
{
	/* algorithm names are matched verbatim against the server's lists */
	static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
	static char cipheralgs[] = "chacha20-poly1305@openssh.com";
	static char zipalgs[] = "none";
	static char macalgs[] = "";	/* poly1305 is applied in sendpkt/recvpkt */
	static char langs[] = "";

	uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
	uchar k12[2*ChachaKeylen];
	int i, nk, nys, nks, nsig;
	DigestState *ds;
	mpint *S, *K;
	RSApub *pub;

	/* H = hash(V_C, V_S, I_C, I_S, K_S, Q_C, Q_S, K) */
	ds = hashstr(send.v, strlen(send.v), nil);
	ds = hashstr(recv.v, strlen(recv.v), ds);

	genrandom(cookie, sizeof(cookie));
	sendpkt("b[ssssssssssbu", MSG_KEXINIT,
		cookie, sizeof(cookie),
		kexalgs, sizeof(kexalgs)-1,
		rsasha256, sizeof(rsasha256)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		macalgs, sizeof(macalgs)-1,
		macalgs, sizeof(macalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		langs, sizeof(langs)-1,
		langs, sizeof(langs)-1,
		0,	/* first_kex_packet_follows */
		0);	/* reserved */
	ds = hashstr(send.r, send.w-send.r, ds);

	if(!gotkexinit){
	Next0:	switch(recvpkt()){
		default:
			dispatch();
			goto Next0;
		case MSG_KEXINIT:
			break;
		}
	}
	ds = hashstr(recv.r, recv.w-recv.r, ds);

	if(debug){
		char *tab[] = {
			"kexalgs", "hostalgs",
			"cipher1", "cipher2",
			"mac1", "mac2",
			"zip1", "zip2",
			"lang1", "lang2",
			nil,
		}, **t, *s;
		/* skip the message number and the 16-byte cookie */
		uchar *p = recv.r+17;
		int n;
		for(t=tab; *t != nil; t++){
			if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
				break;
			fprint(2, "%s: %.*s\n", *t, utfnlen(s, n), s);
		}
	}

	/* ephemeral key pair: x is secret, yc (Q_C) is sent */
	curve25519_dh_new(x, yc);
	yc[31] &= ~0x80;

	sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_ECDH_REPLY:
		/* host key K_S, server ephemeral Q_S, signature over H */
		if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
			sysfatal("bad ECDH_REPLY");
		break;
	}

	if(nys != 32)
		sysfatal("bad server ECDH ephermal public key length");

	ds = hashstr(ks, nks, ds);
	ds = hashstr(yc, 32, ds);
	ds = hashstr(ys, 32, ds);

	/* first exchange only: check the host key fingerprint */
	if(thumb[0] == 0){
		Thumbprint *ok;

		sha2_256(ks, nks, h, nil);
		i = enc64(thumb, sizeof(thumb), h, sizeof(h));
		while(i > 0 && thumb[i-1] == '=')
			i--;
		thumb[i] = '\0';

		if(debug)
			fprint(2, "host fingerprint: %s\n", thumb);

		ok = initThumbprints(thumbfile, nil, "ssh");
		if(ok == nil || !okThumbprint(h, sizeof(h), ok)){
			if(ok != nil) werrstr("unknown host");
			fprint(2, "%s: %r\n", argv0);
			fprint(2, "verify hostkey: %s %.*[\n", sshrsa, nks, ks);
			fprint(2, "add thumbprint after verification:\n");
			fprint(2, "\techo 'ssh sha256=%s server=%s' >> %q\n", thumb, host, thumbfile);
			sysfatal("checking hostkey failed: %r");
		}
		freeThumbprints(ok);
	}

	if((pub = ssh2rsapub(ks, nks)) == nil)
		sysfatal("bad server public key");
	if((S = ssh2rsasig(sig, nsig)) == nil)
		sysfatal("bad server signature");

	if(!curve25519_dh_finish(x, ys, z))
		sysfatal("unlucky shared key");

	/* K as an SSH mpint: minimal big-endian bytes, leading zero if needed */
	K = betomp(z, 32, nil);
	nk = (mpsignif(K)+8)/8;
	mptober(K, k, nk);
	mpfree(K);

	ds = hashstr(k, nk, ds);
	sha2_256(nil, 0, h, ds);
	if(!pkcs1verify(h, sizeof(h), pub, S))
		sysfatal("server verification failed");
	mpfree(S);
	rsapubfree(pub);

	sendpkt("b", MSG_NEWKEYS);
Next2:	switch(recvpkt()){
	default:
		dispatch();
		goto Next2;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_NEWKEYS:
		break;
	}

	/* next key exchange */
	recv.kex = recv.seq + 100000;
	send.kex = send.seq + 100000;

	if(nsid == 0)
		memmove(sid, h, nsid = sizeof(h));

	/*
	 * 'C' keys client->server, 'D' server->client.  The second half
	 * of k12 keys the length cipher cs1, the first half cs2.
	 */
	kdf(k, nk, h, 'C', k12, sizeof(k12));
	setupChachastate(&send.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&send.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);

	kdf(k, nk, h, 'D', k12, sizeof(k12));
	setupChachastate(&recv.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&recv.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
}

/* methods the server still accepts, from the last USERAUTH_FAILURE; nil until known */
static char *authnext;

/*
 * Report whether authentication method meth is worth trying: always
 * while the server's list is unknown, otherwise only if the list
 * mentions meth (substring match).
 */
int
authok(char *meth)
{
	int ok;

	ok = 1;
	if(authnext != nil && strstr(authnext, meth) == nil)
		ok = 0;
if(debug)
	fprint(2, "userauth %s %s\n", meth, ok ? "ok" : "skipped");
	return ok;
}

/*
 * Process a USERAUTH_FAILURE received while trying meth: record the
 * server's list of methods that can continue in authnext.  Returns
 * nonzero if the caller should give up on meth (partial success, or
 * meth is no longer offered), 0 if meth may be retried.
 */
int
authfailure(char *meth)
{
	char *list;
	int listlen, partial;

	if(unpack(recv.r, recv.w-recv.r, "_sb", &list, &listlen, &partial) < 0)
		sysfatal("bad auth failure response");
	free(authnext);
	authnext = smprint("%.*s", utfnlen(list, listlen), list);
if(debug)
	fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
	if(partial != 0)
		return 1;
	return !authok(meth);
}

/*
 * Try the "none" method, which succeeds only if the server requires
 * no authentication.  On failure the server's list of acceptable
 * methods is recorded by authfailure.  Returns 0 on success, -1
 * otherwise.
 */
int
noneauth(void)
{
	static char authmeth[] = "none";

	if(!authok(authmeth))
		return -1;

	sendpkt("bsss", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1);

	for(;;){
		switch(recvpkt()){
		case MSG_USERAUTH_SUCCESS:
			return 0;
		case MSG_USERAUTH_FAILURE:
			werrstr("authentication needed");
			authfailure(authmeth);
			return -1;
		default:
			dispatch();
			break;
		}
	}
}

/*
 * Try the "publickey" method with each RSA key that factotum offers
 * (proto=rsa service=ssh role=client).  For each key, first ask the
 * server whether it would accept the key, then have factotum sign
 * sid followed by the userauth request, and send the signed request.
 * Returns 0 on success, -1 if no key was accepted or factotum or the
 * encoding failed.
 */
int
pubkeyauth(void)
{
	static char authmeth[] = "publickey";

	uchar pk[4096], sig[4096];
	int npk, nsig;

	int afd, n;
	char *s;
	mpint *S;
	AuthRpc *rpc;
	RSApub *pub;

	if(!authok(authmeth))
		return -1;

	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
		return -1;
	if((rpc = auth_allocrpc(afd)) == nil){
		close(afd);
		return -1;
	}

	s = "proto=rsa service=ssh role=client";
	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
		auth_freerpc(rpc);
		close(afd);
		return -1;
	}

	pub = rsapuballoc();
	pub->n = mpnew(0);
	pub->ek = mpnew(0);

	/* each read yields the next key as hex "n ek" */
	while(auth_rpc(rpc, "read", nil, 0) == ARok){
		s = rpc->arg;
		if(strtomp(s, &s, 16, pub->n) == nil)
			break;
		if(*s++ != ' ')
			continue;
		if(strtomp(s, nil, 16, pub->ek) == nil)
			continue;
		if((npk = rsapub2ssh(pub, pk, sizeof(pk))) < 0)
			continue;

		/* query without a signature: would this key be accepted? */
		sendpkt("bsssbss", MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			0,
			rsasha256, sizeof(rsasha256)-1,
			pk, npk);
Next1:		switch(recvpkt()){
		default:
			dispatch();
			goto Next1;
		case MSG_USERAUTH_FAILURE:
			if(authfailure(authmeth))
				goto Failed;
			continue;
		case MSG_USERAUTH_SUCCESS:
		case MSG_USERAUTH_PK_OK:
			break;
		}

		/* sign sid and the userauth request (send.b is scratch here) */
		n = pack(send.b, sizeof(send.b), "sbsssbss",
			sid, nsid,
			MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			1,
			rsasha256, sizeof(rsasha256)-1,
			pk, npk);
		if(n < 0)
			break;
		if((S = pkcs1digest(send.b, n, pub)) == nil)
			break;
		n = snprint((char*)send.b, sizeof(send.b), "%B", S);
		mpfree(S);

		if(auth_rpc(rpc, "write", (char*)send.b, n) != ARok)
			break;
		if(auth_rpc(rpc, "read", nil, 0) != ARok)
			break;

		/* factotum's reply is the signature in hex */
		if((S = strtomp(rpc->arg, nil, 16, nil)) == nil)
			break;
		nsig = rsasig2ssh(pub, S, sig, sizeof(sig));
		mpfree(S);
		if(nsig < 0)
			break;

		/* send final userauth request with the signature */
		sendpkt("bsssbsss", MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			1,
			rsasha256, sizeof(rsasha256)-1,
			pk, npk,
			sig, nsig);
Next2:		switch(recvpkt()){
		default:
			dispatch();
			goto Next2;
		case MSG_USERAUTH_FAILURE:
			if(authfailure(authmeth))
				goto Failed;
			continue;
		case MSG_USERAUTH_SUCCESS:
			break;
		}
		rsapubfree(pub);
		auth_freerpc(rpc);
		close(afd);
		return 0;
	}
Failed:
	rsapubfree(pub);
	auth_freerpc(rpc);
	close(afd);
	return -1;
}

/*
 * Try the "password" method with a password from factotum
 * (proto=pass service=ssh, keyed by user, server and host thumbprint).
 * The password is overwritten in memory right after it is sent.
 * Returns 0 on success, -1 otherwise.
 */
int
passauth(void)
{
	static char authmeth[] = "password";
	UserPasswd *up;
	int plen;

	if(!authok(authmeth))
		return -1;

	up = auth_getuserpasswd(auth_getkey, "proto=pass service=ssh user=%q server=%q thumb=%q",
		user, host, thumb);
	if(up == nil)
		return -1;

	plen = strlen(up->passwd);
	sendpkt("bsssbs", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1,
		0,
		up->passwd, plen);
	memset(up->passwd, 0, plen);
	free(up);

	for(;;){
		switch(recvpkt()){
		case MSG_USERAUTH_SUCCESS:
			return 0;
		case MSG_USERAUTH_FAILURE:
			werrstr("wrong password");
			authfailure(authmeth);
			return -1;
		default:
			dispatch();
			break;
		}
	}
}

/*
 * Try the "keyboard-interactive" method, starting over up to
 * MaxPwTries times.  For each INFO_REQUEST the name and instruction
 * are printed on /dev/cons and every prompt is answered with readcons,
 * echoing as the server requests.  Answers are scrubbed from memory
 * once sent.  Returns 0 on success, -1 on failure.
 */
int
kbintauth(void)
{
	static char authmeth[] = "keyboard-interactive";
	int tries;

	char *name, *inst, *s, *a;
	int cons, i, n, m;
	int nquest, echo;
	uchar *ans, *answ;
	tries = 0;

	if(!authok(authmeth))
		return -1;

Loop:
	if(++tries > MaxPwTries)
		return -1;

	/* empty language tag and submethods */
	sendpkt("bsssss", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1,
		"", 0,
		"", 0);

Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_USERAUTH_FAILURE:
		werrstr("keyboard-interactive failed");
		if(authfailure(authmeth))
			return -1;
		goto Loop;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	case MSG_USERAUTH_INFO_REQUEST:
		break;
	}
Retry:
	/* a distinct name: the global fd is the server connection */
	if((cons = open("/dev/cons", OWRITE)) < 0)
		return -1;

	if(unpack(recv.r, recv.w-recv.r, "_ss.", &name, &n, &inst, &m, &recv.r) < 0)
		sysfatal("bad info request: name, inst");

	while(n > 0 && strchr("\r\n\t ", name[n-1]) != nil)
		n--;
	while(m > 0 && strchr("\r\n\t ", inst[m-1]) != nil)
		m--;

	if(n > 0)
		fprint(cons, "%.*s\n", utfnlen(name, n), name);
	if(m > 0)
		fprint(cons, "%.*s\n", utfnlen(inst, m), inst);

	/* lang, nprompt */
	if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
		sysfatal("bad info request: lang, #quest");

	ans = answ = nil;
	for(i = 0; i < nquest; i++){
		if(unpack(recv.r, recv.w-recv.r, "sb.", &s, &n, &echo, &recv.r) < 0)
			sysfatal("bad info request: question [%d]", i);

		while(n > 0 && strchr("\r\n\t :", s[n-1]) != nil)
			n--;
		/* terminate in place: s[n] is a trimmed byte or the already-parsed echo flag */
		s[n] = '\0';

		if((a = readcons(s, nil, !echo)) == nil)
			sysfatal("readcons: %r");

		/* append the answer to ans as an SSH string */
		n = answ - ans;
		m = strlen(a)+4;
		if((s = realloc(ans, n + m)) == nil)
			sysfatal("realloc: %r");
		ans = (uchar*)s;
		answ = ans+n;
		answ += pack(answ, m, "s", a, m-4);

		/* the answer may be a password: scrub it before freeing */
		memset(a, 0, m-4);
		free(a);
	}

	sendpkt("bu[", MSG_USERAUTH_INFO_RESPONSE, i, ans, answ - ans);
	if(ans != nil)
		memset(ans, 0, answ - ans);
	free(ans);
	close(cons);

Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_USERAUTH_INFO_REQUEST:
		goto Retry;
	case MSG_USERAUTH_FAILURE:
		werrstr("keyboard-interactive failed");
		if(authfailure(authmeth))
			return -1;
		goto Loop;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	}
}

/*
 * Handle the packet in recv.r..recv.w that the current protocol step
 * did not consume.  Transport messages are handled first: ignore,
 * global request (declined), disconnect (fatal), debug, banner, and
 * KEXINIT (re-key).  In mux mode every other packet's payload is
 * copied raw to stdout.  Otherwise channel messages are processed:
 * data goes to stdout, extended data of type 1 to stderr, window
 * adjustments may wake the sender, exit-status/exit-signal set status,
 * and EOF/CLOSE end the session.  In non-mux mode anything else,
 * including malformed packets or another channel, is fatal.
 */
void
dispatch(void)
{
	char *s;
	uchar *p;
	int n, b, c;

	switch(recv.r[0]){
	case MSG_IGNORE:
		return;
	case MSG_GLOBAL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &b) < 0)
			break;
		if(debug)
			fprint(2, "%s: global request: %.*s\n",
				argv0, utfnlen(s, n), s);
		/* want_reply set: decline */
		if(b != 0)
			sendpkt("b", MSG_REQUEST_FAILURE);
		return;
	case MSG_DISCONNECT:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		sysfatal("disconnect: (%d) %.*s", c, utfnlen(s, n), s);
		return;
	case MSG_DEBUG:
		if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
			break;
		if(c != 0 || debug)
			fprint(2, "%s: %.*s\n", argv0, utfnlen(s, n), s);
		return;
	case MSG_USERAUTH_BANNER:
		if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
			break;
		if(raw) write(2, s, n);
		return;
	case MSG_KEXINIT:
		kex(1);
		return;
	}

	if(mux){
		n = recv.w - recv.r;
		if(write(1, recv.r, n) != n)
			sysfatal("write out: %r");
		return;
	}

	switch(recv.r[0]){
	case MSG_CHANNEL_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		if(write(1, s, n) != n)
			sysfatal("write out: %r");
	Winadjust:
		/* consumed n bytes of our window; top it up when below one packet */
		recv.win -= n;
		if(recv.win < recv.pkt){
			n = WinPackets*recv.pkt;
			recv.win += n;
			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
		}
		return;
	case MSG_CHANNEL_EXTENDED_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		if(b == 1) write(2, s, n);
		goto Winadjust;
	case MSG_CHANNEL_WINDOW_ADJUST:
		if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		/*
		 * The adjustment is a uint32 on the wire, so n may be
		 * negative here.  Saturate at the largest int instead of
		 * letting send.win overflow (undefined) or go negative,
		 * which would stall the sender for good.  send.win itself
		 * may be negative while the sender waits for space.
		 */
		if(n < 0 || (send.win > 0 && n > 0x7FFFFFFF - send.win))
			send.win = 0x7FFFFFFF;
		else
			send.win += n;
		if(send.win >= send.pkt)
			rwakeup(&send);
		return;
	case MSG_CHANNEL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
			break;
		if(c != recv.chan)
			break;
		if(n == 11 && memcmp(s, "exit-signal", n) == 0){
			if(unpack(p, recv.w-p, "s", &s, &n) < 0)
				break;
			if(n != 0 && status == nil)
				status = smprint("%.*s", utfnlen(s, n), s);
			c = MSG_CHANNEL_SUCCESS;
		} else if(n == 11 && memcmp(s, "exit-status", n) == 0){
			if(unpack(p, recv.w-p, "u", &n) < 0)
				break;
			if(n != 0 && status == nil)
				status = smprint("%d", n);
			c = MSG_CHANNEL_SUCCESS;
		} else {
			if(debug)
				fprint(2, "%s: channel request: %.*s\n",
					argv0, utfnlen(s, n), s);
			c = MSG_CHANNEL_FAILURE;
		}
		if(b != 0)
			sendpkt("bu", c, recv.chan);
		return;
	case MSG_CHANNEL_EOF:
		recv.eof = 1;
		if(!raw) write(1, "", 0);
		return;
	case MSG_CHANNEL_CLOSE:
		shutdown();
		return;
	}
	sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
}

/*
 * Read one line from fd, a byte at a time, into send.b and return it
 * with trailing CR/LF removed.  Stops at newline, end of file, error,
 * or when the buffer is full; the result is always NUL-terminated.
 * The string lives in send.b and is overwritten by the next packet,
 * so callers must copy it if they keep it.
 */
char*
readline(void)
{
	int i, last;

	last = sizeof(send.b)-1;
	send.b[last] = '\0';
	for(i = 0; i < last; i++){
		send.b[i] = '\0';
		if(read(fd, &send.b[i], 1) != 1 || send.b[i] == '\n')
			break;
	}
	/* strip the line terminator without stepping before send.b */
	while(i >= 0 && (send.b[i] == '\n' || send.b[i] == '\r'))
		send.b[i--] = '\0';
	return (char*)send.b;
}

static struct {
	char	*term;
	int	xpixels;
	int	ypixels;
	int	lines;
	int	cols;
	int	gen;
} tty;

/*
 * Refresh tty's dimensions from $XPIXELS, $YPIXELS, $LINES and $COLS.
 * $WINCH is treated as a generation counter (presumably bumped by
 * vt(1) on window change): if it is set and unchanged since the last
 * call, nothing is re-read and 0 is returned.  Otherwise returns 1.
 */
int
getdim(void)
{
	char *s;
	int g;

	if(s = getenv("WINCH")){
		g = atoi(s);
		free(s);
		if(tty.gen == g)
			return 0;
		tty.gen = g;
	}
	if(s = getenv("XPIXELS")){
		tty.xpixels = atoi(s);
		free(s);
	}
	if(s = getenv("YPIXELS")){
		tty.ypixels = atoi(s);
		free(s);
	}
	if(s = getenv("LINES")){
		tty.lines = atoi(s);
		free(s);
	}
	if(s = getenv("COLS")){
		tty.cols = atoi(s);
		free(s);
	}
	return 1;
}

/* Reopen /dev/cons on descriptor want with the given mode, or die. */
static void
reopencons(int want, int mode)
{
	close(want);
	if(open("/dev/cons", mode) != want)
		sysfatal("open: %r");
}

/*
 * Switch the console to raw mode: reopen /dev/cons as stdin, stdout
 * and stderr, write "rawon" and "winchon" to /dev/consctl, and load
 * the initial window size.  The consctl descriptor is deliberately
 * kept open; closing it would presumably end raw mode (see cons(3)).
 */
void
rawon(void)
{
	int ctl;

	reopencons(0, OREAD);
	reopencons(1, OWRITE);
	dup(1, 2);

	ctl = open("/dev/consctl", OWRITE);
	if(ctl >= 0){
		write(ctl, "rawon", 5);
		write(ctl, "winchon", 7);	/* vt(1): interrupt note on window change */
	}
	getdim();
}

#pragma	   varargck    type  "k"   char*

/*
 * %k: format a char* as a single-quoted shell word, writing each
 * embedded ' as '\'' (presumably for the shell that runs the "exec"
 * command on the server).  The argument is modified temporarily and
 * restored before returning, so it must be writable.
 * Returns the accumulated results of the fmtstrcpy calls.
 */
int
kfmt(Fmt *f)
{
	char *s, *p;
	int n;

	s = va_arg(f->args, char*);
	n = fmtstrcpy(f, "'");
	while((p = strchr(s, '\'')) != nil){
		*p = '\0';
		n += fmtstrcpy(f, s);
		*p = '\'';
		n += fmtstrcpy(f, "'\\''");
		s = p+1;
	}
	n += fmtstrcpy(f, s);
	n += fmtstrcpy(f, "'");
	return n;
}

/*
 * Print the command synopsis and exit with status "usage".
 * Options must come before the host (ARGBEGIN stops at the first
 * non-option argument); with -h the host argument is omitted.
 */
void
usage(void)
{
	fprint(2, "usage: %s [-dRrX] [-t thumbfile] [-T tries] [-u user] [-W remote!port] [-h host | [user@]host] [cmd args...]\n", argv0);
	exits("usage");
}

/*
 * ssh client entry point.  Parses options, dials host, exchanges
 * version strings, runs the key exchange and user authentication,
 * then opens a session channel (or a direct-tcpip channel with -W)
 * and requests a pty and a shell, exec or subsystem as appropriate.
 * Finally it forks: the parent reads and dispatches packets from the
 * server, and the child sends standard input to the server.  With -X
 * (mux) packet payloads are passed through stdin/stdout instead.
 */
void
main(int argc, char *argv[])
{
	static QLock sl;	/* serializes packet sending/dispatch between the two processes */
	int b, n, c;
	char *s;

	quotefmtinstall();
	fmtinstall('B', mpfmt);
	fmtinstall('H', encodefmt);
	fmtinstall('[', encodefmt);
	fmtinstall('k', kfmt);

	/* default to pty (raw) mode when a terminal type is known */
	tty.gen = -1;
	tty.term = getenv("TERM");
	if(tty.term == nil)
		tty.term = "";
	raw = *tty.term != 0;

	ARGBEGIN {
	case 'd':
		debug++;
		break;
	case 'W':
		/* forward to remote!port (or remote:port) over direct-tcpip */
		remote = EARGF(usage());
		s = strrchr(remote, '!');
		if(s == nil)
			s = strrchr(remote, ':');
		if(s == nil)
			usage();
		*s++ = 0;
		port = atoi(s);
		raw = 0;
		break;
	case 'R':
		raw = 0;
		break;
	case 'r':
		raw = 2; /* bloody */
		break;
	case 'u':
		user = EARGF(usage());
		break;
	case 'h':
		host = EARGF(usage());
		break;
	case 't':
		thumbfile = EARGF(usage());
		break;
	case 'T':
		MaxPwTries = strtol(EARGF(usage()), &s, 0);
		if(*s != 0) usage();
		break;
	case 'X':
		mux = 1;
		raw = 0;
		break;
	default:
		usage();
	} ARGEND;

	if(host == nil){
		if(argc == 0)
			usage();
		host = *argv++;
	}

	/* user@host */
	if(user == nil){
		s = strchr(host, '@');
		if(s != nil){
			*s++ = '\0';
			user = host;
			host = s;
		}
	}

	/* join the remaining arguments; all but the first are quoted with %k */
	for(cmd = nil; *argv != nil; argv++){
		if(cmd == nil){
			cmd = strdup(*argv);
			/* a command gets no pty unless -r forced raw mode */
			if(raw == 1)
				raw = 0;
		}else{
			s = smprint("%s %k", cmd, *argv);
			free(cmd);
			cmd = s;
		}
	}

	if(remote != nil && cmd != nil)
		usage();

	if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
		sysfatal("dial: %r");

	/* version exchange: only the server's first line is read */
	send.v = "SSH-2.0-(9)";
	fprint(fd, "%s\r\n", send.v);
	recv.v = readline();
	if(debug)
		fprint(2, "server version: %s\n", recv.v);
	if(strncmp("SSH-2.0-", recv.v, 8) != 0)
		sysfatal("bad server version: %s", recv.v);
	/* readline returns send.b, which packets will reuse */
	recv.v = strdup(recv.v);

	/* the Rendez in each Oneway uses the shared lock */
	send.l = recv.l = &sl;

	if(user == nil)
		user = getuser();
	if(thumbfile == nil)
		thumbfile = smprint("%s/lib/sshthumbs", getenv("home"));

	kex(0);

	sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_SERVICE_ACCEPT:
		break;
	}

	/* try methods in order until one succeeds */
	service = "ssh-connection";
	if(noneauth() < 0 && pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
		sysfatal("auth: %r");

	recv.pkt = send.pkt = MaxPacket;
	recv.win = send.win =  WinPackets*recv.pkt;
	/*
	 * NOTE(review): this zeroes send.win right after setting it and
	 * may have been meant as send.chan.  send.win is only used after
	 * the open confirmation below overwrites it (mux mode never
	 * consults it), so behavior is the same either way.
	 */
	recv.chan = send.win = 0;

	if(mux)
		goto Mux;

	/* open hailing frequencies */
	if(remote != nil){
		NetConnInfo *nci = getnetconninfo(nil, fd);
		if(nci == nil)
			sysfatal("can't get netconninfo: %r");
		sendpkt("bsuuususu", MSG_CHANNEL_OPEN,
			"direct-tcpip", 12,
			recv.chan,
			recv.win,
			recv.pkt,
			remote, strlen(remote),
			port,
			nci->laddr, strlen(nci->laddr),
			atoi(nci->lserv));
		free(nci);
	} else {
		sendpkt("bsuuu", MSG_CHANNEL_OPEN,
			"session", 7,
			recv.chan,
			recv.win,
			recv.pkt);
	}
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_CHANNEL_OPEN_FAILURE:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			n = strlen(s = "???");
		sysfatal("channel open failure: (%d) %.*s", b, utfnlen(s, n), s);
	case MSG_CHANNEL_OPEN_CONFIRMATION:
		break;
	}

	/* our channel, the server's channel, its initial window and max packet */
	if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
		sysfatal("bad channel open confirmation");
	/* never send more per packet than our buffers hold */
	if(send.pkt <= 0 || send.pkt > MaxPacket)
		send.pkt = MaxPacket;

	if(remote != nil)
		goto Mux;

	/* with a terminal, request a pty of the current size */
	if(raw) {
		rawon();
		sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
			send.chan,
			"pty-req", 7,
			0,
			tty.term, strlen(tty.term),
			tty.cols,
			tty.lines,
			tty.xpixels,
			tty.ypixels,
			"", 0);
	}
	/* no command: shell; "#name": subsystem name; otherwise exec cmd */
	if(cmd == nil){
		sendpkt("busb", MSG_CHANNEL_REQUEST,
			send.chan,
			"shell", 5,
			0);
	} else if(*cmd == '#') {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"subsystem", 9,
			0,
			cmd+1, strlen(cmd)-1);
	} else {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"exec", 4,
			0,
			cmd, strlen(cmd));
	}

Mux:
	notify(catch);
	atexit(shutdown);

	recv.pid = getpid();
	/* RFMEM: both processes share send/recv state */
	n = rfork(RFPROC|RFMEM);
	if(n < 0)
		sysfatal("fork: %r");

	/* parent reads and dispatches packets */
	if(n > 0) {
		send.pid = n;
		while(recv.eof == 0){
			recvpkt();
			qlock(&sl);					
			dispatch();
			/* re-key when either direction reaches its packet limit */
			if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
				kex(0);
			qunlock(&sl);
		}
		exits(status);
	}

	/* child reads input and sends packets */
	qlock(&sl);
	for(;;){
		static uchar buf[MaxPacket];
		/* don't hold the lock while blocked reading input */
		qunlock(&sl);
		n = read(0, buf, send.pkt);
		qlock(&sl);
		if(send.eof)
			break;
		if(n < 0 && wasintr())
			intr = 1;
		/* interrupt note: in raw mode a window change or ^C, otherwise quit */
		if(intr){
			if(!raw) break;
			if(getdim()){
				sendpkt("busbuuuu", MSG_CHANNEL_REQUEST,
					send.chan,
					"window-change", 13,
					0,
					tty.cols,
					tty.lines,
					tty.xpixels,
					tty.ypixels);
			}else{
				sendpkt("busbs", MSG_CHANNEL_REQUEST,
					send.chan,
					"signal", 6,
					0,
					"INT", 3);
			}
			intr = 0;
			continue;
		}
		if(n <= 0)
			break;
		if(mux){
			sendpkt("[", buf, n);
			continue;
		}
		/* wait for window space; rsleep releases sl while sleeping */
		send.win -= n;
		while(send.win < 0)
			rsleep(&send);
		sendpkt("bus", MSG_CHANNEL_DATA,
			send.chan,
			buf, n);
	}
	/* input done: CLOSE in raw mode, EOF otherwise; in mux mode stop the reader */
	if(send.eof++ == 0 && !mux)
		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
	else if(recv.pid > 0 && mux)
		postnote(PNPROC, recv.pid, "shutdown");
	qunlock(&sl);

	exits(nil);
}