shithub: riscv

ref: 39f18c9d88f52a22373790dec5721fa3521d3f00
dir: /sys/src/cmd/aux/wpa.c/

View raw version
#include <u.h>
#include <libc.h>
#include <ip.h>
#include <mp.h>
#include <libsec.h>
#include <auth.h>

enum {
	PMKlen = 256/8,
	PTKlen = 512/8,
	GTKlen = 256/8,

	MIClen = 16,

	Noncelen = 32,
	Eaddrlen = 6,
};

enum {
	Fptk	= 1<<3,
	Fins	= 1<<6,
	Fack	= 1<<7,
	Fmic	= 1<<8,
	Fsec	= 1<<9,
	Ferr	= 1<<10,
	Freq	= 1<<11,
	Fenc	= 1<<12,

	Keydescrlen = 1+2+2+8+32+16+8+8+16+2,
};

typedef struct Keydescr Keydescr;
/*
 * EAPOL-Key descriptor exactly as it appears on the wire; the
 * structure is overlaid on received/transmitted frames, so field
 * order and sizes must not change.  Multi-byte fields are decoded
 * byte-wise by their users (e.g. repc big-endian, rsc little-endian
 * in main()).
 */
struct Keydescr
{
	uchar	type[1];	/* descriptor type; main() accepts 0xFE and 0x02 */
	uchar	flags[2];	/* key information, see F* bits */
	uchar	keylen[2];
	uchar	repc[8];	/* replay counter */
	uchar	nonce[32];
	uchar	eapoliv[16];	/* IV used by rc4unwrap() */
	uchar	rsc[8];		/* receive sequence counter */
	uchar	id[8];
	uchar	mic[16];
	uchar	datalen[2];	/* length of data[] */
	uchar	data[];
};

/* a cipher name as understood by the driver's rxkey/txkey ctl, and its key length in bytes */
typedef struct Cipher Cipher;
struct Cipher
{
	char	*name;
	int	keylen;
};

typedef struct Eapconn Eapconn;
typedef struct TLStunn TLStunn;

/* state for one EAPOL conversation with the authenticator */
struct Eapconn
{
	int	fd;		/* where (*write) sends frames */
	int	version;	/* EAPOL version byte used in replies */

	uchar	type;		/* EAPOL packet type byte used in replies */
	uchar	smac[Eaddrlen];	/* our (station) address */
	uchar	amac[Eaddrlen];	/* authenticator (AP) address */

	TLStunn	*tunn;		/* active TTLS/PEAP tunnel, or nil */

	/* transmit hook: eapwrite, fdwrite or peapwrite */
	void	(*write)(Eapconn*, uchar *data, int datalen);
};

/* a TLS tunnel (EAP-TTLS or PEAP) run by a forked client process */
struct TLStunn
{
	int	fd;		/* our end of the pipe to the TLS client process */

	int	clientpid;	/* TLS client (ttlsclient/peapclient) */
	int	readerpid;	/* tlsreader process */

	uchar	id;		/* last EAP identifier accepted for this tunnel */
	uchar	tp;		/* EAP type: 21 (TTLS) or 25 (PEAP) */
};

/* key lengths are the number of ptk+32 / gtk bytes handed to the driver */
Cipher	tkip = { "tkip", 32 };
Cipher	ccmp = { "ccmp", 16 };

Cipher	*peercipher;	/* pairwise cipher selected for this association */
Cipher	*groupcipher;	/* group cipher announced by the AP */

int	background;	/* set once we have (or would have) forked into the background */
int	prompt;		/* -p: ask factotum to prompt for the key */
int	debug;		/* -d */
int	fd, cfd;	/* EAPOL (0x888e) data and ctl files of dev */
char	*dev;
int	ispsk;		/* 1 for PSK, cleared by buildrsne() for 802.1X */
char	devdir[40];
/*
 * pairwise transient key: ptk[0..15] is used for the MIC (calcmic),
 * ptk[16..31] for key-data unwrapping and ptk[32..] is installed as
 * the pairwise cipher key.
 */
uchar	ptk[PTKlen];
char	essid[32+1];
uvlong	lastrepc;	/* highest replay counter accepted so far */

/* RSN suite selectors (OUI 00-0F-AC) */
uchar rsntkipoui[4] = {0x00, 0x0F, 0xAC, 0x02};
uchar rsnccmpoui[4] = {0x00, 0x0F, 0xAC, 0x04};
uchar rsnapskoui[4] = {0x00, 0x0F, 0xAC, 0x02};
uchar rsnawpaoui[4] = {0x00, 0x0F, 0xAC, 0x01};

/* default RSN (WPA2) IE offered with -2 */
uchar	rsnie[] = {
	0x30,			/* RSN */
	0x14,			/* length */
	0x01, 0x00,		/* version 1 */
	0x00, 0x0F, 0xAC, 0x04,	/* group cipher suite CCMP */
	0x01, 0x00,		/* pairwise cipher suite count 1 */
	0x00, 0x0F, 0xAC, 0x04,	/* pairwise cipher suite CCMP */
	0x01, 0x00,		/* authentication suite count 1 */
	0x00, 0x0F, 0xAC, 0x02,	/* authentication suite PSK */
	0x00, 0x00,		/* capabilities */
};

/* WPA1 suite selectors (OUI 00-50-F2) */
uchar wpa1oui[4]    = {0x00, 0x50, 0xF2, 0x01};
uchar wpatkipoui[4] = {0x00, 0x50, 0xF2, 0x02};
uchar wpaapskoui[4] = {0x00, 0x50, 0xF2, 0x02};
uchar wpaawpaoui[4] = {0x00, 0x50, 0xF2, 0x01};

/* default WPA1 IE offered with -1, or when no RSNE could be built */
uchar	wpaie[] = {
	0xdd,			/* vendor specific */
	0x16,			/* length */
	0x00, 0x50, 0xf2, 0x01,	/* WPAIE type 1 */
	0x01, 0x00,		/* version 1 */
	0x00, 0x50, 0xf2, 0x02,	/* group cipher suite TKIP */
	0x01, 0x00,		/* pairwise cipher suite count 1 */
	0x00, 0x50, 0xf2, 0x02,	/* pairwise cipher suite TKIP */
	0x01, 0x00,		/* authentication suite count 1 */
	0x00, 0x50, 0xf2, 0x02,	/* authentication suite PSK */
};

/*
 * emalloc: zero-filled allocation of len bytes; exits via
 * sysfatal instead of returning nil.
 */
void*
emalloc(int len)
{
	void *mem;

	mem = mallocz(len, 1);
	if(mem == nil)
		sysfatal("malloc: %r");
	return mem;
}

/*
 * hextob: decode hex digits from s into at most n bytes at b.
 * Decoding stops at NUL, at the first non-hex character, or once
 * n bytes are complete.  Returns the number of complete bytes;
 * if sp is not nil, *sp is set to where decoding stopped.
 * A trailing odd nibble is shifted into the next byte but not counted.
 */
int
hextob(char *s, char **sp, uchar *b, int n)
{
	int nib, c, v;

	for(nib = 0; nib < 2*n; s++){
		c = *s;
		if(c == 0)
			break;
		*b <<= 4;	/* done before validation, as callers always have */
		if(c >= '0' && c <= '9')
			v = c - '0';
		else if(c >= 'a' && c <= 'f')
			v = c - 'a' + 10;
		else if(c >= 'A' && c <= 'F')
			v = c - 'A' + 10;
		else
			break;
		*b |= v;
		nib++;
		if((nib & 1) == 0)
			b++;	/* two nibbles make a byte */
	}
	if(sp != nil)
		*sp = s;
	return nib / 2;
}

/*
 * getifstats: look up key (e.g. "essid:") in devdir/ifstats and
 * copy its first value token, NUL-terminated, into val[nval].
 * Returns val, or nil if the file cannot be read, the key is
 * absent or nval is not positive.  The last line of the file is
 * examined even when it lacks a trailing newline.
 */
char*
getifstats(char *key, char *val, int nval)
{
	char buf[8*1024], *f[2], *p, *e;
	int fd, n;

	if(nval <= 0)
		return nil;	/* val[nval-1] below would be out of bounds */
	snprint(buf, sizeof(buf), "%s/ifstats", devdir);
	if((fd = open(buf, OREAD)) < 0)
		return nil;
	n = readn(fd, buf, sizeof(buf)-1);
	close(fd);
	if(n <= 0)
		return nil;
	buf[n] = 0;
	for(p = buf; *p != 0; p = e){
		if((e = strchr(p, '\n')) != nil)
			*e++ = 0;
		else
			e = p + strlen(p);	/* final line without newline */
		if(gettokens(p, f, 2, "\t\r\n ") != 2)
			continue;
		if(strcmp(f[0], key) != 0)
			continue;
		strncpy(val, f[1], nval);
		val[nval-1] = 0;	/* strncpy does not terminate on truncation */
		return val;
	}
	return nil;
}

/*
 * getessid: refresh the global essid[] from ifstats.
 * Returns essid, or nil if the interface reports none.
 */
char*
getessid(void)
{
	char *s;

	s = getifstats("essid:", essid, sizeof essid);
	return s;
}

/*
 * getbssid: read the current BSSID from ifstats into mac.
 * Returns parseether()'s result, or -1 if no bssid is reported.
 */
int
getbssid(uchar mac[Eaddrlen])
{
	char bssid[64];

	if(getifstats("bssid:", bssid, sizeof(bssid)) == nil)
		return -1;
	return parseether(mac, bssid);
}

/*
 * connected: report whether the interface status indicates we are
 * past "connecting"/"unauthenticated".  With assoc set, only
 * "blocked" or "associated" count as connected.
 */
int
connected(int assoc)
{
	char st[1024];

	if(getifstats("status:", st, sizeof st) == nil)
		return 0;
	if(strcmp(st, "connecting") == 0 || strcmp(st, "unauthenticated") == 0)
		return 0;
	if(assoc && strcmp(st, "blocked") != 0 && strcmp(st, "associated") != 0)
		return 0;
	if(debug)
		fprint(2, "status: %s\n", st);
	return 1;
}

/*
 * buildrsne: build the RSN (0x30) or WPA (0xDD) IE we send in the
 * "auth" ctl message, based on the AP's IE reported as "brsne:" in
 * ifstats.  The result offers exactly one pairwise cipher (CCMP if
 * the AP's RSN IE lists it, else TKIP) and the first PSK or 802.1X
 * auth suite found.
 *
 * Side effects: sets groupcipher and peercipher; clears ispsk when
 * an 802.1X ("WPA") auth suite is chosen.
 *
 * Returns the IE length written to rsne, or 0 if ifstats has no
 * brsne entry.  Malformed or unsupported IEs are fatal (sysfatal).
 */
int
buildrsne(uchar rsne[258])
{
	char buf[1024];
	uchar brsne[258];
	int brsnelen;
	uchar *p, *w, *e;
	int i, n;

	if(getifstats("brsne:", buf, sizeof(buf)) == nil)
		return 0;	/* not an error, might be old kernel */

	brsnelen = hextob(buf, nil, brsne, sizeof(brsne));
	if(brsnelen <= 4){
trunc:		sysfatal("invalid or truncated RSNE; brsne: %s", buf);
		return 0;
	}

	/* p reads the AP's IE, w writes ours; e bounds the input */
	w = rsne;
	p = brsne;
	e = p + brsnelen;
	if(p[0] == 0x30){
		p += 2;

		/* RSN */
		*w++ = 0x30;
		*w++ = 0;	/* length */
	} else if(p[0] == 0xDD){
		p += 2;
		if((e - p) < 4 || memcmp(p, wpa1oui, 4) != 0){
			sysfatal("unrecognized WPAIE type; brsne: %s", buf);
			return 0;
		}

		/* WPA */
		*w++ = 0xDD;
		*w++ = 0;	/* length */

		memmove(w, wpa1oui, 4);
		w += 4;
		p += 4;
	} else {
		sysfatal("unrecognized RSNE type; brsne: %s", buf);
		return 0;
	}

	/* version (2) + group cipher (4) */
	if((e - p) < 6)
		goto trunc;

	*w++ = *p++;		/* version */
	*w++ = *p++;

	if(rsne[0] == 0x30){
		if(memcmp(p, rsnccmpoui, 4) == 0)
			groupcipher = &ccmp;
		else if(memcmp(p, rsntkipoui, 4) == 0)
			groupcipher = &tkip;
		else {
			sysfatal("unrecognized RSN group cipher; brsne: %s", buf);
			return 0;
		}
	} else {
		if(memcmp(p, wpatkipoui, 4) != 0){
			sysfatal("unrecognized WPA group cipher; brsne: %s", buf);
			return 0;
		}
		groupcipher = &tkip;
	}

	memmove(w, p, 4);	/* group cipher */
	w += 4;
	p += 4;

	/* pairwise count (2) + at least one suite (4) */
	if((e - p) < 6)
		goto trunc;

	*w++ = 0x01;		/* # of peer ciphers */
	*w++ = 0x00;
	n = *p++;
	n |= *p++ << 8;		/* little-endian count */

	if(n <= 0)
		goto trunc;

	/* prefer CCMP (RSN only); otherwise fall back to TKIP */
	peercipher = &tkip;
	for(i=0; i<n; i++){
		if((e - p) < 4)
			goto trunc;

		if(rsne[0] == 0x30 && memcmp(p, rsnccmpoui, 4) == 0 && peercipher == &tkip)
			peercipher = &ccmp;
		p += 4;
	}
	if(peercipher == &ccmp)
		memmove(w, rsnccmpoui, 4);
	else if(rsne[0] == 0x30)
		memmove(w, rsntkipoui, 4);
	else
		memmove(w, wpatkipoui, 4);
	w += 4;

	/* auth suite count (2) + at least one suite (4) */
	if((e - p) < 6)
		goto trunc;

	*w++ = 0x01;		/* # of auth suites */
	*w++ = 0x00;
	n = *p++;
	n |= *p++ << 8;

	if(n <= 0)
		goto trunc;

	/* take the first PSK or 802.1X suite; p is left pointing at it */
	for(i=0; i<n; i++){
		if((e - p) < 4)
			goto trunc;

		if(rsne[0] == 0x30){
			/* look for PSK oui */
			if(memcmp(p, rsnapskoui, 4) == 0)
				break;
			/* look for WPA oui */
			if(memcmp(p, rsnawpaoui, 4) == 0){
				ispsk = 0;
				break;
			}
		} else {
			/* look for PSK oui */
			if(memcmp(p, wpaapskoui, 4) == 0)
				break;
			/* look for WPA oui */
			if(memcmp(p, wpaawpaoui, 4) == 0){
				ispsk = 0;
				break;
			}
		}
		p += 4;
	}
	if(i >= n){
		sysfatal("auth suite is not PSK or WPA; brsne: %s", buf);
		return 0;
	}

	memmove(w, p, 4);
	w += 4;

	if(rsne[0] == 0x30){
		/* RSN caps */
		*w++ = 0x00;
		*w++ = 0x00;
	}

	/* patch the IE length: everything after the 2-byte header */
	rsne[1] = (w - rsne) - 2;
	return w - rsne;
}

/*
 * factotumattr: start a factotum rpc with the printf-style key
 * query fmt and return a strdup'd copy of attribute attr from the
 * matching key, or nil.  The caller frees the result.
 */
char*
factotumattr(char *attr, char *fmt, ...)
{
	char query[1024];
	va_list ap;
	AuthRpc *rpc;
	Attr *attrs;
	char *found;
	int rpcfd;

	rpcfd = open("/mnt/factotum/rpc", ORDWR);
	if(rpcfd < 0)
		return nil;
	rpc = auth_allocrpc(rpcfd);
	if(rpc == nil){
		close(rpcfd);
		return nil;
	}

	va_start(ap, fmt);
	vsnprint(query, sizeof(query), fmt, ap);
	va_end(ap);

	found = nil;
	if(auth_rpc(rpc, "start", query, strlen(query)) == 0
	&& (attrs = auth_attr(rpc)) != nil){
		found = _strfindattr(attrs, attr);
		if(found != nil)
			found = strdup(found);
		_freeattr(attrs);
	}

	auth_freerpc(rpc);
	close(rpcfd);
	return found;
}

/*
 * freeup: scrub the user name and password held in up,
 * then free it.
 */
void
freeup(UserPasswd *up)
{
	memset(up->passwd, 0, strlen(up->passwd));
	memset(up->user, 0, strlen(up->user));
	free(up);
}

/*
 * getidentity: EAP identity for the current essid.  It is taken
 * from a factotum proto=pass key if present, otherwise from a
 * proto=mschapv2 key.  The last identity found is remembered and
 * used as a fallback, with "anonymous" as the initial default.
 * The returned string is owned by this function.
 */
char*
getidentity(void)
{
	static char *identity;
	char *found;

	found = nil;
	if(getessid() != nil){
		found = factotumattr("user", "proto=pass service=wpa essid=%q", essid);
		if(found == nil)
			found = factotumattr("user", "proto=mschapv2 role=client service=wpa essid=%q", essid);
	}

	if(found != nil){
		free(identity);
		identity = found;
	} else if(identity == nil)
		identity = strdup("anonymous");

	if(debug)
		fprint(2, "identity: %s\n", identity);
	return identity;
}

/*
 * factotumctl: write one printf-formatted message to factotum's
 * ctl file in a single write.  The formatted text is scrubbed
 * afterwards because it may carry a secret.
 * Returns write()'s result, or -1 if ctl cannot be opened or
 * the message cannot be formatted.
 */
int
factotumctl(char *fmt, ...)
{
	va_list ap;
	char *msg;
	int ctl, len, rv;

	ctl = open("/mnt/factotum/ctl", OWRITE);
	if(ctl < 0)
		return -1;

	va_start(ap, fmt);
	msg = vsmprint(fmt, ap);
	va_end(ap);

	rv = -1;
	if(msg != nil){
		len = strlen(msg);
		rv = write(ctl, msg, len);
		memset(msg, 0, len);
		free(msg);
	}
	close(ctl);
	return rv;
}

/*
 * setpmk: store pmk in factotum as the wpapsk key for the current
 * essid.  Returns -1 if no essid is known, else factotumctl()'s
 * result.
 */
int
setpmk(uchar pmk[PMKlen])
{
	if(getessid() != nil)
		return factotumctl("key proto=wpapsk role=client essid=%q !password=%.*H\n", essid, PMKlen, pmk);
	return -1;
}

/*
 * getptk: have factotum (proto=wpapsk) derive the PTK for the
 * current essid from smac, amac, snonce and anonce, and store it
 * in ptk.  If getkey is given and factotum answers ARneedkey or
 * ARbadkey, getkey is invoked with factotum's reply so the key
 * can be prompted for.
 * Returns 0 on success, otherwise a non-zero AR* code or -1.
 */
int
getptk(AuthGetkey *getkey,
	uchar smac[Eaddrlen], uchar amac[Eaddrlen],
	uchar snonce[Noncelen], uchar anonce[Noncelen],
	uchar ptk[PTKlen])
{
	uchar seed[2*Eaddrlen + 2*Noncelen];
	AuthRpc *rpc;
	char *id, *query;
	int afd, rv;

	rv = -1;
	rpc = nil;
	query = nil;
	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
		goto done;
	if((rpc = auth_allocrpc(afd)) == nil)
		goto done;
	if((id = getessid()) == nil)
		goto done;
	if((query = smprint("proto=wpapsk role=client essid=%q", id)) == nil)
		goto done;
	if((rv = auth_rpc(rpc, "start", query, strlen(query))) != ARok)
		goto done;

	/* factotum expects smac | amac | snonce | anonce */
	memmove(seed, smac, Eaddrlen);
	memmove(seed+Eaddrlen, amac, Eaddrlen);
	memmove(seed+2*Eaddrlen, snonce, Noncelen);
	memmove(seed+2*Eaddrlen+Noncelen, anonce, Noncelen);
	if((rv = auth_rpc(rpc, "write", seed, sizeof(seed))) != ARok)
		goto done;
	if((rv = auth_rpc(rpc, "read", nil, 0)) != ARok)
		goto done;
	if(rpc->narg != PTKlen){
		rv = -1;
		goto done;
	}
	memmove(ptk, rpc->arg, PTKlen);
	rv = 0;

done:
	/* rv can only be ARneedkey/ARbadkey once rpc exists */
	if(getkey != nil && (rv == ARneedkey || rv == ARbadkey))
		(*getkey)(rpc->arg);
	free(query);
	if(afd >= 0)
		close(afd);
	if(rpc != nil)
		auth_freerpc(rpc);
	return rv;
}

/*
 * Hfmt: %H verb, prints the precision-many bytes of a uchar*
 * argument as lower-case hex pairs.  Without a positive precision
 * nothing is printed.
 */
int
Hfmt(Fmt *f)
{
	uchar *s;
	int i, n;

	s = va_arg(f->args, uchar*);
	n = f->prec;
	for(i = 0; i < n; i++)
		if(fmtprint(f, "%.2x", s[i]) < 0)
			return -1;
	return 0;
}

/*
 * dumpkeydescr: debug dump of an EAPOL-Key descriptor to fd 2.
 * Prints datalen bytes of kd->data, so the caller must ensure
 * that many bytes are present.
 */
void
dumpkeydescr(Keydescr *kd)
{
	static struct {
		int	flag;
		char	*name;
	} flagnames[] = {
		{ Fptk,	"ptk" },
		{ Fins,	"ins" },
		{ Fack,	"ack" },
		{ Fmic,	"mic" },
		{ Fsec,	"sec" },
		{ Ferr,	"err" },
		{ Freq,	"req" },
		{ Fenc,	"enc" },
	};
	int bits, i, dlen;

	bits = kd->flags[0]<<8 | kd->flags[1];
	fprint(2, "type=%.*H vers=%d flags=%.*H ( ",
		sizeof(kd->type), kd->type, bits & 7,
		sizeof(kd->flags), kd->flags);
	for(i = 0; i < nelem(flagnames); i++){
		if((bits & flagnames[i].flag) != 0)
			fprint(2, "%s ", flagnames[i].name);
	}
	fprint(2, ") len=%.*H\nrepc=%.*H nonce=%.*H\neapoliv=%.*H rsc=%.*H id=%.*H mic=%.*H\n",
		sizeof(kd->keylen), kd->keylen,
		sizeof(kd->repc), kd->repc,
		sizeof(kd->nonce), kd->nonce,
		sizeof(kd->eapoliv), kd->eapoliv,
		sizeof(kd->rsc), kd->rsc,
		sizeof(kd->id), kd->id,
		sizeof(kd->mic), kd->mic);
	dlen = kd->datalen[0]<<8 | kd->datalen[1];
	fprint(2, "data[%.4x]=%.*H\n", dlen, dlen, kd->data);
}

/*
 * rc4unwrap: decrypt len bytes of data in place with RC4 keyed by
 * iv || key, discarding the first 256 keystream bytes.
 * Always returns len.
 */
int
rc4unwrap(uchar key[16], uchar iv[16], uchar *data, int len)
{
	RC4state st;
	uchar k[32];

	memmove(k, iv, 16);
	memmove(k+16, key, 16);
	setupRC4state(&st, k, sizeof k);
	rc4skip(&st, 256);
	rc4(&st, data, len);
	return len;
}

/*
 * aesunwrap: AES key unwrap (the algorithm of RFC 3394) of len
 * bytes at data, in place, with the nkey-byte key.  On success the
 * plaintext (len-8 bytes) is left at the start of data and its
 * length returned.  Returns -1 if len is not 8 + a multiple of 8
 * of at least 16, or if the integrity check value does not match.
 */
int
aesunwrap(uchar *key, int nkey, uchar *data, int len)
{
	static uchar IV[8] = { 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, };
	uchar B[16], *R;
	AESstate s;
	uint t;
	int n;

	len -= 8;
	if(len < 16 || (len % 8) != 0)
		return -1;
	n = len/8;		/* number of 64-bit plaintext blocks */
	t = n*6;		/* step counter, counts down from 6n to 1 */
	setupAESstate(&s, key, nkey, 0);
	memmove(B, data, 8);	/* B[0..7] holds the running value A */
	memmove(data, data+8, n*8);	/* R[1..n] now start at data */
	do {
		/* one pass over R[n]..R[1]; six passes in total */
		for(R = data + (n - 1)*8; R >= data; t--, R -= 8){
			memmove(B+8, R, 8);
			/* A ^= t, t as a big-endian 64-bit value */
			B[7] ^= (t >> 0);
			B[6] ^= (t >> 8);
			B[5] ^= (t >> 16);
			B[4] ^= (t >> 24);
			aes_decrypt(s.dkey, s.rounds, B, B);
			memmove(R, B+8, 8);
		}
	} while(t > 0);
	/* the recovered A must equal the default IV */
	if(memcmp(B, IV, 8) != 0)
		return -1;
	return n*8;
}

/*
 * calcmic: compute the EAPOL-Key MIC over msg[0..msglen) with the
 * first 16 bytes of ptk and store it in kd->mic.  kd is expected
 * to lie inside msg, so the MIC field is zeroed before hashing.
 * Descriptor version 1 uses HMAC-MD5, version 2 HMAC-SHA1.
 * Returns 0, or -1 for any other version.
 */
int
calcmic(Keydescr *kd, uchar *msg, int msglen)
{
	uchar digest[SHA1dlen];	/* large enough for both digests */

	memset(kd->mic, 0, MIClen);
	switch(kd->flags[1] & 7){
	case 1:
		hmac_md5(msg, msglen, ptk, 16, digest, nil);
		break;
	case 2:
		hmac_sha1(msg, msglen, ptk, 16, digest, nil);
		break;
	default:
		return -1;
	}
	memmove(kd->mic, digest, MIClen);
	return 0;
}

/*
 * checkmic: verify the MIC carried in kd against one recomputed
 * over msg.  kd->mic is left holding the recomputed value.
 * Returns 0 on match, 1 on mismatch, -1 if the version is unsupported.
 */
int
checkmic(Keydescr *kd, uchar *msg, int msglen)
{
	uchar rx[MIClen];

	memmove(rx, kd->mic, MIClen);
	if(calcmic(kd, msg, msglen) != 0)
		return -1;
	if(memcmp(rx, kd->mic, MIClen) != 0)
		return 1;
	return 0;
}

/* fdwrite: write data to conn->fd; a short or failed write is fatal */
void
fdwrite(Eapconn *conn, uchar *data, int len)
{
	int n;

	n = write(conn->fd, data, len);
	if(n != len)
		sysfatal("write: %r");
}

/*
 * etherwrite: wrap data in an ethernet header (dst amac, src smac,
 * type 0x888e), zero-pad to 60 bytes and send it via fdwrite().
 */
void
etherwrite(Eapconn *conn, uchar *data, int len)
{
	uchar *frame;
	int flen;

	if(debug)
		fprint(2, "\nreply(v%d,t%d) %E -> %E: ", conn->version, conn->type, conn->smac, conn->amac);
	flen = 2*Eaddrlen + 2 + len;
	if(flen < 60)
		flen = 60;	/* ETHERMINTU; emalloc zero-fills the padding */
	frame = emalloc(flen);
	memmove(frame, conn->amac, Eaddrlen);
	memmove(frame+Eaddrlen, conn->smac, Eaddrlen);
	frame[2*Eaddrlen] = 0x88;
	frame[2*Eaddrlen+1] = 0x8e;
	memmove(frame+2*Eaddrlen+2, data, len);
	fdwrite(conn, frame, flen);
	free(frame);
}

/*
 * eapwrite: prefix data with a 4-byte EAPOL header (version, type,
 * big-endian body length) and send it with etherwrite().
 */
void
eapwrite(Eapconn *conn, uchar *data, int len)
{
	uchar *pkt;

	pkt = emalloc(4 + len);
	pkt[0] = conn->version;
	pkt[1] = conn->type;
	pkt[2] = len >> 8;
	pkt[3] = len;
	memmove(pkt+4, data, len);
	etherwrite(conn, pkt, 4 + len);
	free(pkt);
}

/*
 * replykey: send an EAPOL-Key frame built from a copy of kd with
 * the given key-info flags and datalen bytes of key data.  When
 * Fmic is set the MIC is computed over the whole EAPOL frame;
 * otherwise the MIC field is sent as zeroes.
 */
void
replykey(Eapconn *conn, int flags, Keydescr *kd, uchar *data, int datalen)
{
	uchar buf[4096];
	Keydescr *out;
	int bodylen, total;

	bodylen = Keydescrlen + datalen;
	total = 4 + bodylen;

	/* eapol header */
	buf[0] = conn->version;
	buf[1] = conn->type;
	buf[2] = bodylen >> 8;
	buf[3] = bodylen;

	/* key descriptor: copied from the request, then patched */
	out = (Keydescr*)(buf+4);
	memmove(out, kd, Keydescrlen);
	out->flags[0] = flags >> 8;
	out->flags[1] = flags;
	out->datalen[0] = datalen >> 8;
	out->datalen[1] = datalen;
	memmove(out->data, data, datalen);

	memset(out->mic, 0, MIClen);
	if(flags & Fmic)
		calcmic(out, buf, total);
	etherwrite(conn, buf, total);
	if(debug)
		dumpkeydescr(out);
}

/*
 * eapresp: send an EAP packet with the given code and identifier
 * and data as payload, through conn->write.
 */
void
eapresp(Eapconn *conn, int code, int id, uchar *data, int len)
{
	uchar *pkt;
	int total;

	total = 4 + len;
	pkt = emalloc(total);
	/* eap header */
	pkt[0] = code;
	pkt[1] = id;
	pkt[2] = total >> 8;
	pkt[3] = total;
	memmove(pkt+4, data, len);
	(*conn->write)(conn, pkt, total);
	free(pkt);

	if(debug)
		fprint(2, "eapresp(code=%d, id=%d, data=%.*H)\n", code, id, len, data);
}

/*
 * tlsreader: runs in its own process (see eapreq).  It reads TLS
 * records produced by the tunnel client from tunn->fd and forwards
 * them to the authenticator as EAP responses of type tunn->tp.
 * Until a Change Cipher Spec has been seen, Certificate and Client
 * Key Exchange handshake records and the CCS itself are batched
 * with the following records into one response.
 * Stops on read error/EOF, on an alert record, or when conn no
 * longer refers to this tunnel.
 */
void
tlsreader(TLStunn *tunn, Eapconn *conn)
{
	enum {
		Tlshdrsz = 5,
		TLStunnhdrsz = 6,	/* EAP type + flags + 4-byte length */
	};
	uchar *rec, *w, *p;
	int fd, n, css;

	fd = tunn->fd;
	rec = nil;
	css = 0;
Reset:
	/* leave room in front of the records for the tunnel header */
	w = rec;
	w += TLStunnhdrsz;
	for(;;w += n){
		/* grow rec for the next record header; w is rebased after realloc */
		if((p = realloc(rec, (w - rec) + Tlshdrsz)) == nil)
			break;
		w = p + (w - rec), rec = p;
		if(readn(fd, w, Tlshdrsz) != Tlshdrsz)
			break;
		n = w[3]<<8 | w[4];	/* record length */
		if(n < 1)
			break;
		if((p = realloc(rec, (w - rec) + Tlshdrsz+n)) == nil)
			break;
		w = p + (w - rec), rec = p;
		if(readn(fd, w+Tlshdrsz, n) != n)
			break;
		n += Tlshdrsz;	
		
		/* batch records that need to be send together */
		if(!css){
			/* Client Certificate */
			if(w[0] == 22 && w[5] == 11)
				continue;
			/* Client Key Exchange */
			if(w[0] == 22 && w[5] == 16)
				continue;
			/* Change Cipher Spec */
			if(w[0] == 20){
				css = 1;
				continue;
			}
		}

		/* do not forward alert, close connection */
		if(w[0] == 21)
			break;

		/* check if we'r still the tunnel for this connection */
		if(conn->tunn != tunn)
			break;

		/* flush records in encapsulation */
		p = rec + TLStunnhdrsz;
		w += n;
		n = w - p;
		/* prepend type, flags and big-endian total length */
		*(--p) = n;
		*(--p) = n >> 8;
		*(--p) = n >> 16;
		*(--p) = n >> 24;
		*(--p) = 0x80;	/* flags: Length included */
		*(--p) = tunn->tp;

		eapresp(conn, 2, tunn->id, p, w - p);
		goto Reset;
	}
	free(rec);
}

/* tunnel client bodies, run in a forked process by eapreq() */
void ttlsclient(int);
void peapclient(int);

/*
 * eapreset: detach the active TLS tunnel (if any) from conn and
 * post "kill" to its client process.  The reader process notices
 * the detachment itself (see tlsreader).
 */
void
eapreset(Eapconn *conn)
{
	TLStunn *t;

	if((t = conn->tunn) == nil)
		return;
	if(debug)
		fprint(2, "eapreset: kill client %d\n", t->clientpid);
	conn->tunn = nil;
	postnote(PNPROC, t->clientpid, "kill");
}

/*
 * tlswrap: run a TLS client handshake over fd and return the
 * resulting fd; failure is fatal.  When label is not nil, 128
 * bytes of keying material are exported with that label and the
 * first PMKlen bytes are stored in factotum via setpmk(); the
 * material is then scrubbed.
 * The server certificate is currently not checked.
 */
int
tlswrap(int fd, char *label)
{
	TLSconn *c;
	int tfd;

	c = emalloc(sizeof(TLSconn));
	if(debug)
		c->trace = print;
	if(label != nil){
		/* tls client computes the 1024 bit MSK for us */
		c->sessionType = "ttls";
		c->sessionConst = label;
		c->sessionKeylen = 128;
		c->sessionKey = emalloc(c->sessionKeylen);
	}

	tfd = tlsClient(fd, c);
	if(tfd < 0)
		sysfatal("tlsClient: %r");

	if(label != nil && c->sessionKey != nil){
		/* the PMK is the leading 256 bits of the MSK */
		if(setpmk(c->sessionKey) < 0)
			sysfatal("setpmk: %r");
		memset(c->sessionKey, 0, c->sessionKeylen);
	}

	free(c->cert);	/* TODO: check cert */
	free(c->sessionID);
	free(c->sessionKey);
	free(c);
	return tfd;
}

/*
 * eapreq: handle one EAP packet received from the authenticator.
 * code and id come from the EAP header; data[0..datalen) is the
 * type-data, starting with the EAP type byte.  Requests are
 * answered with eapresp(); Success and Failure tear down any
 * active TLS tunnel.  data may be modified in place to build a
 * response.
 */
void
eapreq(Eapconn *conn, int code, int id, uchar *data, int datalen)
{
	TLStunn *tunn;
	int tp, frag, delta;
	char *user;
	uchar *idresp;

	if(debug)
		fprint(2, "eapreq(code=%d, id=%d, data=%.*H)\n", code, id, datalen, data);

	switch(code){
	case 1:	/* Request */
		break;
	case 4:	/* Failure */
	case 3:	/* Success */
		eapreset(conn);
		if(code == 4 || debug)
			fprint(2, "%s: eap code %s\n", argv0, code == 3 ? "Success" : "Failure");
		return;
	default:
	unhandled:
		if(debug)
			fprint(2, "unhandled: %.*H\n", datalen < 0 ? 0 : datalen, data);
		return;
	}
	if(datalen < 1)
		goto unhandled;

	tp = data[0];
	switch(tp){
	case 1:		/* Identity */
		/*
		 * build the reply in a buffer of its own: the identity
		 * comes from factotum and may be longer than the request
		 * buffer we were handed.
		 */
		user = getidentity();
		datalen = 1+strlen(user);
		idresp = emalloc(datalen);
		idresp[0] = tp;
		memmove(idresp+1, user, datalen-1);
		eapresp(conn, 2, id, idresp, datalen);
		free(idresp);
		return;
	case 2:		/* Notification */
		fprint(2, "%s: eap error: %.*s\n", argv0, datalen-1, (char*)data+1);
		return;
	case 33:	/* EAP Extensions (AVP) */
		if(debug)
			fprint(2, "eap extension: %.*H\n", datalen, data);
		eapresp(conn, 2, id, data, datalen);
		return;
	case 26:	/* MS-CHAP-V2 */
		data++;
		datalen--;
		if(datalen < 1)
			break;

		/* OpCode */
		switch(data[0]){
		case 1:	/* Challenge */
			if(datalen > 4) {
				uchar cid, chal[16], resp[48];
				char user[256+1];
				int len;

				cid = data[1];
				len = data[2]<<8 | data[3];
				if(data[4] != sizeof(chal))
					break;
				if(len > datalen || (5 + data[4]) > len)
					break;
				memmove(chal, data+5, sizeof(chal));
				memset(user, 0, sizeof(user));
				memset(resp, 0, sizeof(resp));
				if(auth_respond(chal, sizeof(chal), user, sizeof(user), resp, sizeof(resp), nil,
					"proto=mschapv2 role=client service=wpa essid=%q", essid) < 0){
					fprint(2, "%s: eap mschapv2: auth_respond: %r\n", argv0);
					break;
				}
				len = 5 + sizeof(resp) + 1 + strlen(user);
				data[0] = 2;		/* OpCode - Response */
				data[1] = cid;		/* Identifier */
				data[2] = len >> 8;
				data[3] = len;
				data[4] = sizeof(resp)+1;	/* ValueSize */
				memmove(data+5, resp, sizeof(resp));
				data[5 + sizeof(resp)] = 0;	/* flags */
				strcpy((char*)&data[5 + sizeof(resp) + 1], user);

				*(--data) = tp, len++;
				eapresp(conn, 2, id, data, len);
				return;
			}
			break;

		case 3:	/* Success */
		case 4:	/* Failure */
			if(debug || data[0] == 4)
				fprint(2, "%s: eap mschapv2 %s: %.*s\n", argv0,
					data[0] == 3 ? "Success" : "Failure",
					datalen < 4 ? 0 : datalen-4, (char*)data+4);
			*(--data) = tp;
			eapresp(conn, 2, id, data, 2);
			return;
		}
		break;

	case 21:	/* EAP-TTLS */
	case 25:	/* PEAP */
		if(datalen < 2)
			break;
		datalen -= 2;
		data++;		/* now at the flags byte */
		tunn = conn->tunn;
		if(*data & 0x20){	/* flags: start */
			int p[2], pid;

			if(tunn != nil){
				if(tunn->id == id && tunn->tp == tp)
					break;	/* is retransmit, ignore */
				eapreset(conn);
			}
			if(pipe(p) < 0)
				sysfatal("pipe: %r");
			if((pid = fork()) == -1)
				sysfatal("fork: %r");
			if(pid == 0){
				close(p[0]);
				switch(tp){
				case 21:
					ttlsclient(p[1]);
					break;
				case 25:
					peapclient(p[1]);
					break;
				}
				exits(nil);
			}
			close(p[1]);
			tunn = emalloc(sizeof(TLStunn));
			tunn->tp = tp;
			tunn->id = id;
			tunn->fd = p[0];
			tunn->clientpid = pid;
			conn->tunn = tunn;
			/* shared-memory reader forwards the client's TLS records */
			if((pid = rfork(RFPROC|RFMEM)) == -1)
				sysfatal("fork: %r");
			if(pid == 0){
				tunn->readerpid = getpid();
				tlsreader(tunn, conn);
				if(conn->tunn == tunn)
					conn->tunn = nil;
				close(tunn->fd);
				free(tunn);
				exits(nil);
			}
			return;
		}
		if(tunn == nil || tunn->tp != tp)
			break;
		/*
		 * The EAP identifier is a single octet and wraps, so
		 * compare with serial-number arithmetic: only ids up to
		 * 127 ahead of the last one are new.  Retransmits (same
		 * id) and stale ids are ignored.
		 */
		delta = (id - tunn->id) & 0xFF;
		if(delta == 0 || delta >= 0x80)
			break;
		frag = *data & 0x40;	/* flags: more fragments */
		if(*data & 0x80){	/* flags: length included */
			if(datalen < 4)
				break;	/* truncated length field */
			datalen -= 4;
			data += 4;
		}
		tunn->id = id;
		data++;		/* now at the TLS data */
		if(datalen > 0)
			write(tunn->fd, data, datalen);
		/* ack fragments, and (PEAP) a record of type 20 (change cipher spec) */
		if(frag || (tp == 25 && datalen > 0 && data[0] == 20)){
			data -= 2;
			data[0] = tp;
			data[1] = 0;
			eapresp(conn, 2, id, data, 2);
		}
		return;
	}
	goto unhandled;
}

/*
 * avp: encode one AVP (code, value val[0..len)) at p, which has
 * room for n bytes (asserted).  pad is a mask: the value is
 * rounded up to a multiple of pad+1 (0 = none, 15 = 16 bytes) and
 * that rounded length is counted in the AVP length field.  The
 * whole AVP is then zero-padded to a 4-byte boundary.
 * Returns the number of bytes written, including padding.
 */
int
avp(uchar *p, int n, int code, void *val, int len, int pad)
{
	int avplen, total;

	avplen = 8 + ((len + pad) & ~pad);	/* header + padded value */
	total = (avplen + 3) & ~3;		/* packet alignment */
	assert(total <= n);

	p[0] = code >> 24;
	p[1] = code >> 16;
	p[2] = code >> 8;
	p[3] = code;
	/*
	 * NOTE(review): in RFC 5281 the M (mandatory) flag is 0x40;
	 * 2 sets a reserved bit.  Kept as-is pending confirmation.
	 */
	p[4] = 2;
	p[5] = avplen >> 16;
	p[6] = avplen >> 8;
	p[7] = avplen;
	memmove(p+8, val, len);
	memset(p+8+len, 0, total - (8+len));
	return total;
}

/* AVP codes (RADIUS attribute numbers) used inside the TTLS tunnel */
enum {
	/* Avp Code */
	AvpUserName = 1,
	AvpUserPass = 2,
	AvpChapPass = 3,	/* currently unused */
	AvpChapChal = 60,	/* currently unused */
};

/*
 * ttlsclient: EAP-TTLS tunnel client.  Runs the TLS handshake
 * over fd (storing the derived PMK), then sends the factotum
 * proto=pass user name and password for essid as AVPs.
 * The credentials are scrubbed after use.
 */
void
ttlsclient(int fd)
{
	uchar msg[4096];
	UserPasswd *up;
	int len;

	fd = tlswrap(fd, "ttls keying material");
	up = auth_getuserpasswd(nil, "proto=pass service=wpa essid=%q", essid);
	if(up == nil)
		sysfatal("auth_getuserpasswd: %r");
	len = avp(msg, sizeof(msg), AvpUserName, up->user, strlen(up->user), 0);
	len += avp(msg+len, sizeof(msg)-len, AvpUserPass, up->passwd, strlen(up->passwd), 15);
	freeup(up);
	write(fd, msg, len);
	memset(msg, 0, len);
}

/*
 * peapwrite: Eapconn write hook used inside the PEAP tunnel for
 * packets that arrived without an EAP header (see peapclient):
 * drops the 4-byte EAP header built by eapresp() and writes
 * the rest to the tunnel.
 */
void
peapwrite(Eapconn *conn, uchar *data, int len)
{
	assert(len >= 4);
	fdwrite(conn, data + 4, len - 4);
}

/*
 * peapclient: PEAP tunnel client.  Runs the TLS handshake over fd
 * (storing the derived PMK), then feeds inner EAP messages to
 * eapreq().  A message that carries a consistent 4-byte EAP header
 * and type 33 is passed on with that header's code/id and answered
 * with full headers (fdwrite).  Anything else is treated as a bare
 * Request with id 0 and answered without the header (peapwrite).
 */
void
peapclient(int fd)
{
	static Eapconn conn;
	uchar buf[4096], *p;
	int n, id, code;

	fd = tlswrap(fd, "client EAP encryption");
	conn.fd = fd;
	for(;;){
		p = buf;
		n = read(fd, p, sizeof(buf));
		if(n <= 0)
			break;
		if(n > 4 && (p[2] << 8 | p[3]) == n && p[4] == 33){
			code = p[0];
			id = p[1];
			p += 4;
			n -= 4;
			conn.write = fdwrite;
		} else {
			code = 1;
			id = 0;
			conn.write = peapwrite;
		}
		eapreq(&conn, code, id, p, n);
	}
}

/* usage: print the command synopsis and exit with status "usage" */
void
usage(void)
{
	fprint(2, "%s: [-dp12] [-s essid] dev\n", argv0);
	exits("usage");
}

/*
 * main: WPA/WPA2 supplicant for the wireless interface dev.
 *
 *	-d		debug output; also stays in the foreground
 *	-p		prompt for the key through factotum first
 *	-s essid	set the essid on the interface
 *	-1, -2		offer the built-in WPA1 (TKIP) or RSN (CCMP)
 *			PSK IE instead of one derived from the AP's IE
 *
 * Waits for the link, sends an "auth" ctl message with our IE,
 * then loops on EAPOL frames.  EAP packets go to eapreq() (802.1X
 * only).  EAPOL-Key frames drive the pairwise and group key
 * handshakes, and the resulting keys are installed via
 * rxkey/txkey ctl messages.  A zero-length read is taken as
 * deassociation and restarts from Connect.
 */
void
main(int argc, char *argv[])
{
	uchar mac[Eaddrlen], buf[4096], snonce[Noncelen], anonce[Noncelen];
	static uchar brsne[258];
	static Eapconn conn;
	char addr[128];
	uchar *rsne;
	int rsnelen;
	int n, try;

	quotefmtinstall();
	fmtinstall('H', Hfmt);
	fmtinstall('E', eipfmt);

	rsne = nil;
	rsnelen = -1;
	peercipher = nil;
	groupcipher = nil;

	ARGBEGIN {
	case 'd':
		debug = 1;
		break;
	case 'p':
		prompt = 1;
		break;
	case 's':
		strncpy(essid, EARGF(usage()), 32);	/* essid[32] stays 0 */
		break;
	case '1':
		rsne = wpaie;
		rsnelen = sizeof(wpaie);
		peercipher = &tkip;
		groupcipher = &tkip;
		break;
	case '2':
		rsne = rsnie;
		rsnelen = sizeof(rsnie);
		peercipher = &ccmp;
		groupcipher = &ccmp;
		break;
	default:
		usage();
	} ARGEND;

	if(*argv != nil)
		dev = *argv++;

	if(*argv != nil || dev == nil)
		usage();

	if(myetheraddr(mac, dev) < 0)
		sysfatal("can't get mac address: %r");

	snprint(addr, sizeof(addr), "%s!0x888e", dev);
	if((fd = dial(addr, nil, devdir, &cfd)) < 0)
		sysfatal("dial: %r");

	if(essid[0] != 0){
		if(fprint(cfd, "essid %q", essid) < 0)
			sysfatal("write essid: %r");
	} else {
		getessid();
		if(essid[0] == 0)
			sysfatal("no essid set");
	}

Connect:
	/* bss scan might not be complete yet, so check for 10 seconds.	*/
	for(try = 10; (background || try >= 0) && !connected(0); try--)
		sleep(1000);

	ispsk = 1;
	/* rebuild our IE from the AP's unless -1/-2 fixed it */
	if(rsnelen <= 0 || rsne == brsne){
		rsne = brsne;
		rsnelen = buildrsne(rsne);
	}

	if(rsnelen <= 0){
		/* default is WPA */
		rsne = wpaie;
		rsnelen = sizeof(wpaie);
		peercipher = &tkip;
		groupcipher = &tkip;
	}

	if(debug)
		fprint(2, "rsne: %.*H\n", rsnelen, rsne);

	/*
	 * we use write() instead of fprint so the message gets written
	 * at once and not chunked up on fprint buffer.
	 */
	n = sprint((char*)buf, "auth %.*H", rsnelen, rsne);
	if(write(cfd, buf, n) != n)
		sysfatal("write auth: %r");

	conn.fd = fd;
	conn.write = eapwrite;
	conn.type = 1;	/* Start */
	conn.version = 1;
	memmove(conn.smac, mac, Eaddrlen);
	getbssid(conn.amac);

	if(prompt){
		prompt = 0;
		if(ispsk){
			/* dummy to for factotum keyprompt */
			genrandom(anonce, sizeof(anonce));
			genrandom(snonce, sizeof(snonce));
			getptk(auth_getkey, conn.smac, conn.amac, snonce, anonce, ptk);
		} else {
			UserPasswd *up;

			if((up = auth_getuserpasswd(auth_getkey, "proto=pass service=wpa essid=%q", essid)) != nil){
				factotumctl("key proto=mschapv2 role=client service=wpa essid=%q user=%q !password=%q\n",
					essid, up->user, up->passwd);
				freeup(up);
			}
		}
	}

	if(!background){
		background = 1;
		if(!debug){
			switch(rfork(RFNOTEG|RFREND|RFPROC|RFNOWAIT)){
			default:
				exits(nil);
			case -1:
				sysfatal("fork: %r");
				return;
			case 0:
				break;
			}
		}
	}

	/* wait for getting associated before sending start message */
	for(try = 10; (background || try >= 0) && !connected(1); try--)
		sleep(500);

	if(getbssid(conn.amac) == 0)
		eapwrite(&conn, nil, 0);

	lastrepc = 0ULL;
	for(;;){
		uchar *p, *e, *m;
		int proto, flags, vers, datalen;
		uvlong repc, rsc, tsc;
		Keydescr *kd;

		if((n = read(fd, buf, sizeof(buf))) < 0)
			sysfatal("read: %r");

		if(n == 0){
			if(debug)
				fprint(2, "got deassociation\n");
			eapreset(&conn);
			goto Connect;
		}

		/* ethernet header: destination (us), source (AP), type */
		p = buf;
		e = buf+n;
		if(n < 2*Eaddrlen + 2)
			continue;

		memmove(conn.smac, p, Eaddrlen); p += Eaddrlen;
		memmove(conn.amac, p, Eaddrlen); p += Eaddrlen;
		proto = p[0]<<8 | p[1]; p += 2;

		if(proto != 0x888e || memcmp(conn.smac, mac, Eaddrlen) != 0)
			continue;

		/* EAPOL header: version, type, big-endian body length */
		m = p;
		n = e - p;
		if(n < 4)
			continue;

		conn.version = p[0];
		if(conn.version != 0x01 && conn.version != 0x02)
			continue;
		conn.type = p[1];
		n = p[2]<<8 | p[3];
		p += 4;
		if(p+n > e)
			continue;
		e = p + n;

		if(debug)
			fprint(2, "\nrecv(v%d,t%d) %E <- %E: ", conn.version, conn.type, conn.smac, conn.amac);

		/* EAPOL type 0: EAP packet (802.1X only) */
		if(conn.type == 0x00 && !ispsk){
			uchar code, id;

			if(n < 4)
				continue;
			code = p[0];
			id = p[1];
			n = p[3] | p[2]<<8;
			if(n < 4 || p + n > e)
				continue;
			p += 4, n -= 4;
			eapreq(&conn, code, id, p, n);
			continue;
		}

		/* EAPOL type 3: key frame */
		if(conn.type != 0x03)
			continue;

		if(n < Keydescrlen){
			if(debug)
				fprint(2, "bad kd size\n");
			continue;
		}
		kd = (Keydescr*)p;

		/*
		 * validate datalen before dumpkeydescr(), which prints
		 * datalen bytes of key data and would otherwise read
		 * past the received frame.
		 */
		datalen = kd->datalen[0]<<8 | kd->datalen[1];
		if(kd->data + datalen > e){
			if(debug)
				fprint(2, "bad kd datalen\n");
			continue;
		}
		if(debug)
			dumpkeydescr(kd);

		if(kd->type[0] != 0xFE && kd->type[0] != 0x02)
			continue;

		vers = kd->flags[1] & 7;
		flags = kd->flags[0]<<8 | kd->flags[1];

		if((flags & Fmic) == 0){
			/* first message of the pairwise handshake: ANonce, no MIC */
			if((flags & (Fptk|Fack)) != (Fptk|Fack))
				continue;

			memmove(anonce, kd->nonce, sizeof(anonce));
			genrandom(snonce, sizeof(snonce));
			if(getptk(nil, conn.smac, conn.amac, snonce, anonce, ptk) != 0){
				if(debug)
					fprint(2, "getptk: %r\n");
				continue;
			}

			/* ack key exchange with mic */
			memset(kd->rsc, 0, sizeof(kd->rsc));
			memset(kd->eapoliv, 0, sizeof(kd->eapoliv));
			memmove(kd->nonce, snonce, sizeof(kd->nonce));
			replykey(&conn, (flags & ~(Fack|Fins)) | Fmic, kd, rsne, rsnelen);
		} else {
			uchar gtk[GTKlen];
			int gtklen, gtkkid;

			if(checkmic(kd, m, e - m) != 0){
				if(debug)
					fprint(2, "bad mic\n");
				continue;
			}

			/* replay counter is big-endian and must increase */
			repc =	(uvlong)kd->repc[7] |
				(uvlong)kd->repc[6]<<8 |
				(uvlong)kd->repc[5]<<16 |
				(uvlong)kd->repc[4]<<24 |
				(uvlong)kd->repc[3]<<32 |
				(uvlong)kd->repc[2]<<40 |
				(uvlong)kd->repc[1]<<48 |
				(uvlong)kd->repc[0]<<56;
			if(repc <= lastrepc){
				if(debug)
					fprint(2, "bad repc: %llux <= %llux\n", repc, lastrepc);
				continue;
			}
			lastrepc = repc;

			/* receive sequence counter: low 6 bytes, little-endian */
			rsc =	(uvlong)kd->rsc[0] |
				(uvlong)kd->rsc[1]<<8 |
				(uvlong)kd->rsc[2]<<16 |
				(uvlong)kd->rsc[3]<<24 |
				(uvlong)kd->rsc[4]<<32 |
				(uvlong)kd->rsc[5]<<40;

			if(datalen > 0 && (flags & Fenc) != 0){
				if(vers == 1)
					datalen = rc4unwrap(ptk+16, kd->eapoliv, kd->data, datalen);
				else
					datalen = aesunwrap(ptk+16, 16, kd->data, datalen);
				if(datalen <= 0){
					if(debug)
						fprint(2, "bad keywrap\n");
					continue;
				}
				if(debug)
					fprint(2, "unwraped keydata[%.4x]=%.*H\n", datalen, datalen, kd->data);
			}

			gtklen = 0;
			gtkkid = -1;

			/* scan the key data IEs for a GTK (00-0F-AC type 1) KDE */
			if(kd->type[0] != 0xFE || (flags & (Fptk|Fack)) == (Fptk|Fack)){
				uchar *p, *x, *e;

				p = kd->data;
				e = p + datalen;
				for(; p+2 <= e; p = x){
					if((x = p+2+p[1]) > e)
						break;
					if(debug)
						fprint(2, "ie=%.2x data[%.2x]=%.*H\n", p[0], p[1], p[1], p+2);
					if(p[0] == 0x30){ /* RSN */
					}
					if(p[0] == 0xDD){ /* WPA */
						static uchar oui[] = { 0x00, 0x0f, 0xac, 0x01, };

						if(p+2+sizeof(oui) > x || memcmp(p+2, oui, sizeof(oui)) != 0)
							continue;
						if((flags & Fenc) == 0)
							continue;	/* ignore group key if unencrypted */
						/* p[6]: key id byte, p[8...]: key */
						gtklen = x - (p + 8);
						if(gtklen <= 0)
							continue;
						if(gtklen > sizeof(gtk))
							gtklen = sizeof(gtk);
						memmove(gtk, p + 8, gtklen);
						gtkkid = p[6] & 3;
					}
				}
			}

			if((flags & (Fptk|Fack)) == (Fptk|Fack)){
				if(vers != 1)	/* in WPA2, RSC is for group key only */
					tsc = 0LL;
				else {
					tsc = rsc;
					rsc = 0LL;
				}
				/* install pairwise receive key */
				if(fprint(cfd, "rxkey %.*H %s:%.*H@%llux", Eaddrlen, conn.amac,
					peercipher->name, peercipher->keylen, ptk+32, tsc) < 0)
					sysfatal("write rxkey: %r");

				tsc = 0LL;
				memset(kd->rsc, 0, sizeof(kd->rsc));
				memset(kd->eapoliv, 0, sizeof(kd->eapoliv));
				memset(kd->nonce, 0, sizeof(kd->nonce));
				replykey(&conn, flags & ~(Fack|Fenc|Fins), kd, nil, 0);
				sleep(100);

				/* install pairwise transmit key */
				if(fprint(cfd, "txkey %.*H %s:%.*H@%llux", Eaddrlen, conn.amac,
					peercipher->name, peercipher->keylen, ptk+32, tsc) < 0)
					sysfatal("write txkey: %r");
			} else
			if((flags & (Fptk|Fsec|Fack)) == (Fsec|Fack)){
				/* group key message */
				if(kd->type[0] == 0xFE){
					/* WPA always RC4 encrypts the GTK, even tho the flag isnt set */
					if((flags & Fenc) == 0)
						datalen = rc4unwrap(ptk+16, kd->eapoliv, kd->data, datalen);
					gtklen = datalen;
					if(gtklen > sizeof(gtk))
						gtklen = sizeof(gtk);
					memmove(gtk, kd->data, gtklen);
					gtkkid = (flags >> 4) & 3;
				}

				memset(kd->rsc, 0, sizeof(kd->rsc));
				memset(kd->eapoliv, 0, sizeof(kd->eapoliv));
				memset(kd->nonce, 0, sizeof(kd->nonce));
				replykey(&conn, flags & ~(Fenc|Fack), kd, nil, 0);
			} else
				continue;

			if(gtklen >= groupcipher->keylen && gtkkid != -1){
				/* install group key */
				if(fprint(cfd, "rxkey%d %.*H %s:%.*H@%llux",
					gtkkid, Eaddrlen, conn.amac,
					groupcipher->name, groupcipher->keylen, gtk, rsc) < 0)
					sysfatal("write rxkey%d: %r", gtkkid);
			}
		}
	}
}