shithub: riscv

ref: ebe88f34cb1823dd2d8f0f7428ce69c7a983397a
dir: /sys/src/cmd/auth/factotum/wpapsk.c/

View raw version
/*
 * WPA-PSK
 *
 * Client protocol:
 *	write challenge: smac[6] + amac[6] + snonce[32] + anonce[32]
 *	read response: ptk[64]
 *
 * Server protocol:
 *	unimplemented
 */
#include "dat.h"

/*
 * Sizes in bytes.  PMKlen is the 256-bit pairwise master key derived
 * from the password; PTKlen is the 512-bit pairwise transient key
 * returned to the client.  Eaddrlen and Noncelen are the lengths of
 * the MAC addresses and nonces in the challenge written by the client.
 */
enum {
	PMKlen = 256/8,
	PTKlen = 512/8,

	Eaddrlen = 6,
	Noncelen = 32,
};

/* client-side protocol phases */
enum
{
	CNeedChal,	/* waiting for the challenge write */
	CHaveResp,	/* PTK computed; waiting for the response read */
	Maxphase,
};

/* printable phase names, indexed by phase; installed as fss->phasename */
static char *phasenames[Maxphase] = {
[CNeedChal]	"CNeedChal",
[CHaveResp]	"CHaveResp",
};

/* per-conversation state, allocated in wpapskinit and freed in wpapskclose */
struct State
{
	uchar	resp[PTKlen];	/* PTK computed by wpapskwrite, returned by wpapskread */
};

/*
 * PBKDF2 with HMAC-SHA1 as the PRF: fill d[0..dlen) with key material
 * derived from password p (plen bytes) and salt s (slen bytes).
 * Output is produced SHA1dlen bytes at a time; output block number blk
 * (starting at 1) is U1 ^ U2 ^ ... ^ U(rounds), where
 *	U1 = HMAC(p, s || big-endian 32-bit blk)
 *	Uj = HMAC(p, U(j-1))
 * The last block is truncated to whatever remains of dlen.
 */
static void
pbkdf2(uchar *p, ulong plen, uchar *s, ulong slen, ulong rounds, uchar *d, ulong dlen)
{
	uchar acc[SHA1dlen], u[SHA1dlen], next[SHA1dlen], ctr[4];
	ulong blk, iter, x, take;
	DigestState *st;

	blk = 1;
	while(dlen > 0){
		/* block index, big-endian */
		ctr[0] = blk >> 24;
		ctr[1] = blk >> 16;
		ctr[2] = blk >> 8;
		ctr[3] = blk;

		/* U1 = HMAC(p, s || ctr) */
		st = hmac_sha1(s, slen, p, plen, nil, nil);
		hmac_sha1(ctr, sizeof ctr, p, plen, u, st);
		memmove(acc, u, sizeof acc);

		/* Uj = HMAC(p, U(j-1)), folded into acc by xor */
		for(iter = 1; iter < rounds; iter++){
			hmac_sha1(u, sizeof u, p, plen, next, nil);
			memmove(u, next, sizeof u);
			for(x = 0; x < sizeof acc; x++)
				acc[x] ^= u[x];
		}

		take = dlen < sizeof acc ? dlen : sizeof acc;
		memmove(d, acc, take);
		d += take;
		dlen -= take;
		blk++;
	}
}

/*
 * Decode up to n bytes of hexadecimal digits from s into b.
 * Decoding stops after n bytes, at the end of the string, or at the
 * first character that is not a hex digit.  Only complete bytes are
 * stored: a trailing unpaired digit is consumed but b is not touched
 * for it, so bytes past the returned count are never modified.
 * If sp is not nil, *sp is set to the first character not consumed.
 * Returns the number of bytes stored in b (0 if n <= 0).
 */
static int
hextob(char *s, char **sp, uchar *b, int n)
{
	int r, v, hi;

	hi = 0;
	for(r = 0; r < 2*n && *s; s++){
		if(*s >= '0' && *s <= '9')
			v = *s - '0';
		else if(*s >= 'a' && *s <= 'f')
			v = 10 + (*s - 'a');
		else if(*s >= 'A' && *s <= 'F')
			v = 10 + (*s - 'A');
		else
			break;
		/* hold the high nibble locally until its partner arrives */
		if((r++ & 1) == 0)
			hi = v << 4;
		else
			*b++ = hi | v;
	}
	if(sp != nil)
		*sp = s;
	return r >> 1;
}

/*
 * Derive the PMKlen-byte PMK from the key's password.  A password that
 * consists of exactly 2*PMKlen hex digits and nothing else is taken as
 * the raw PSK.  Anything else is treated as a passphrase and run
 * through PBKDF2-HMAC-SHA1 with the SSID as salt and 4096 iterations.
 */
static void
pass2pmk(char *pass, char *ssid, uchar pmk[PMKlen])
{
	char *e;

	/*
	 * Require the hex digits to cover the whole string, so a longer
	 * password that merely starts with 64 hex digits is not silently
	 * truncated to a raw PSK.
	 */
	if(hextob(pass, &e, pmk, PMKlen) == PMKlen && *e == '\0')
		return;
	pbkdf2((uchar*)pass, strlen(pass), (uchar*)ssid, strlen(ssid), 4096, pmk, PMKlen);
}

/*
 * PRF from IEEE 802.11i: fill d[0..dlen) with
 *	HMAC-SHA1(k, a || 0 || b || 0) || HMAC-SHA1(k, a || 0 || b || 1) || ...
 * truncated to dlen bytes.  The label's NUL terminator (strlen(a)+1)
 * supplies the zero byte separating a from b; the trailing byte is a
 * one-byte counter.
 */
static void
prfn(uchar *k, int klen, char *a, uchar *b, int blen, uchar *d, int dlen)
{
	uchar out[SHA1dlen], ctr;
	DigestState *st;
	int take;

	for(ctr = 0; dlen > 0; ctr++){
		st = hmac_sha1((uchar*)a, strlen(a)+1, k, klen, nil, nil);
		hmac_sha1(b, blen, k, klen, nil, st);
		hmac_sha1(&ctr, 1, k, klen, out, st);

		take = dlen;
		if(take > SHA1dlen)
			take = SHA1dlen;
		memmove(d, out, take);
		d += take;
		dlen -= take;
	}
}

/*
 * Store the n-byte strings a and b at d in ascending memcmp order,
 * i.e. Min(a,b) || Max(a,b); when they compare equal, a goes first.
 */
static void
putminmax(uchar *d, uchar *a, uchar *b, int n)
{
	if(memcmp(a, b, n) > 0){
		memmove(d, b, n);
		memmove(d + n, a, n);
	} else {
		memmove(d, a, n);
		memmove(d + n, b, n);
	}
}

/*
 * Compute the PTKlen-byte PTK from the PMK and the handshake values:
 *	PRF(PMK, "Pairwise key expansion",
 *	    Min(AA,SPA) || Max(AA,SPA) || Min(ANonce,SNonce) || Max(ANonce,SNonce))
 */
static void
calcptk(uchar pmk[PMKlen], uchar smac[Eaddrlen], uchar amac[Eaddrlen],
	uchar snonce[Noncelen], uchar anonce[Noncelen],
	uchar ptk[PTKlen])
{
	uchar b[2*Eaddrlen + 2*Noncelen];

	putminmax(b, smac, amac, Eaddrlen);
	/* nonces are ordered by comparing all Noncelen bytes */
	putminmax(b + 2*Eaddrlen, snonce, anonce, Noncelen);
	prfn(pmk, PMKlen, "Pairwise key expansion", b, sizeof(b), ptk, PTKlen);
}

/*
 * Start a wpapsk conversation.  Only the client role is supported;
 * any other role, or an unparsable one, fails.  On success a fresh
 * State is attached to fss and the conversation waits for the challenge.
 */
static int
wpapskinit(Proto *p, Fsstate *fss)
{
	State *st;
	int role;

	role = isclient(_strfindattr(fss->attr, "role"));
	if(role < 0)
		return failure(fss, nil);
	if(role == 0)
		return failure(fss, "%s server not supported", p->name);

	st = emalloc(sizeof *st);
	fss->ps = st;
	fss->phasename = phasenames;
	fss->maxphase = Maxphase;
	fss->phase = CNeedChal;
	return RpcOk;
}

/*
 * Accept the client's challenge (smac, amac, snonce, anonce; exactly
 * 2*Eaddrlen + 2*Noncelen bytes).  Look up a key with !password and
 * essid, derive the PMK from them, and compute the PTK into the
 * conversation state for the following read.
 */
static int
wpapskwrite(Fsstate *fss, void *va, uint n)
{
	uchar pmk[PMKlen], *smac, *amac, *snonce, *anonce;
	char *pass, *essid;
	State *s;
	int ret;
	Key *k;
	Keyinfo ki;
	Attr *attr;

	s = fss->ps;

	if(fss->phase != CNeedChal)
		return phaseerror(fss, "write");
	if(n != (2*Eaddrlen + 2*Noncelen))
		return phaseerror(fss, "bad write size");

	attr = _delattr(_copyattr(fss->attr), "role");
	mkkeyinfo(&ki, fss, attr);
	ret = findkey(&k, &ki, "%s", fss->proto->keyprompt);
	_freeattr(attr);
	if(ret != RpcOk)
		return ret;

	/* every exit from here on must release k */
	pass = _strfindattr(k->privattr, "!password");
	if(pass == nil){
		closekey(k);
		return failure(fss, "key has no password");
	}
	essid = _strfindattr(k->attr, "essid");
	if(essid == nil){
		closekey(k);
		return failure(fss, "key has no essid");
	}
	setattrs(fss->attr, k->attr);

	/* pass and essid point into k's attributes: use them before closekey */
	pass2pmk(pass, essid, pmk);
	closekey(k);

	smac = va;
	amac = smac + Eaddrlen;
	snonce = amac + Eaddrlen;
	anonce = snonce + Noncelen;
	calcptk(pmk, smac, amac, snonce, anonce, s->resp);

	/* the PMK is no longer needed; don't leave it on the stack */
	memset(pmk, 0, sizeof pmk);

	fss->phase = CHaveResp;
	return RpcOk;
}

/*
 * Return the computed PTK (at most *n bytes; *n is set to the number
 * copied) and mark the conversation established.  Only valid after a
 * successful challenge write.
 */
static int
wpapskread(Fsstate *fss, void *va, uint *n)
{
	State *st;

	if(fss->phase != CHaveResp)
		return phaseerror(fss, "read");

	st = fss->ps;
	*n = *n < sizeof(st->resp) ? *n : sizeof(st->resp);
	memmove(va, st->resp, *n);

	fss->haveai = 0;
	fss->phase = Established;
	return RpcOk;
}

/*
 * Release the conversation state.  The State holds the derived PTK,
 * so it is scrubbed before being freed.
 */
static void
wpapskclose(Fsstate *fss)
{
	State *s;

	s = fss->ps;
	if(s != nil){
		memset(s, 0, sizeof *s);
		free(s);
	}
	fss->ps = nil;
}

/*
 * Protocol table entry for "wpapsk".  Only the client side is
 * implemented (wpapskinit rejects other roles).  Keys are prompted
 * for with "!password? essid?".
 */
Proto wpapsk = {
.name=		"wpapsk",
.init=		wpapskinit,
.write=		wpapskwrite,
.read=		wpapskread,
.close=		wpapskclose,
.addkey=	replacekey,
.keyprompt=	"!password? essid?"
};