shithub: riscv

Download patch

ref: 54746461641644e24a68cc81c6800c5297852698
parent: 8f087e019f551078b97da4dec4973fe9979c9551
author: cinap_lenrek <[email protected]>
date: Sun Jun 7 12:56:01 EDT 2020

devip: implement ipv6 support in ipmux packet filter

Added a ver= field to the filter to distinguish the IP version.
By default, a filter is parsed as IPv6, and after parsing, the
proto, src and dst fields are converted to IPv4. When no
ver= field is specified, an IP version filter is implicitly
added and the filter is applied to both IPv4 and IPv6 packets.

This change also gets rid of the fast compare types, as the
fields might not be correctly aligned in the packet.

This also fixes the ifc= filter, which now has to check every
local address of the interface, not just the first one.

--- a/sys/man/3/ip
+++ b/sys/man/3/ip
@@ -1126,6 +1126,10 @@
 .TF "\fLdata[\fIn\fL:\fIm\fL]=\fIexpr\fR "
 .PD
 .TP
+.BI ver= n
+the IP version must be
+.IR n .
+.TP
 .BI proto= n
 the IP protocol number must be
 .IR n .
@@ -1135,7 +1139,7 @@
 .I n
 through
 .I m
-following the IP packet must match
+following the IP header must match
 .IR expr .
 .TP
 .BI iph[ n : m ]= expr
--- a/sys/src/9/ip/ipmux.c
+++ b/sys/src/9/ip/ipmux.c
@@ -14,26 +14,9 @@
 typedef struct Ipmuxrock  Ipmuxrock;
 typedef struct Ipmux      Ipmux;
 
-typedef struct Myip4hdr Myip4hdr;
-struct Myip4hdr
-{
-	uchar	vihl;		/* Version and header length */
-	uchar	tos;		/* Type of service */
-	uchar	length[2];	/* packet length */
-	uchar	id[2];		/* ip->identification */
-	uchar	frag[2];	/* Fragment information */
-	uchar	ttl;		/* Time to live */
-	uchar	proto;		/* Protocol */
-	uchar	cksum[2];	/* Header checksum */
-	uchar	src[4];		/* IP source */
-	uchar	dst[4];		/* IP destination */
-
-	uchar	data[1];	/* start of data */
-};
-Myip4hdr *ipoff = 0;
-
 enum
 {
+	Tver,
 	Tproto,
 	Tdata,
 	Tiph,
@@ -40,28 +23,8 @@
 	Tdst,
 	Tsrc,
 	Tifc,
-
-	Cother = 0,
-	Cbyte,		/* single byte */
-	Cmbyte,		/* single byte with mask */
-	Cshort,		/* single short */
-	Cmshort,	/* single short with mask */
-	Clong,		/* single long */
-	Cmlong,		/* single long with mask */
-	Cifc,
-	Cmifc,
 };
 
-char *ftname[] = 
-{
-[Tproto]	"proto",
-[Tdata]		"data",
-[Tiph]	 	"iph",
-[Tdst]		"dst",
-[Tsrc]		"src",
-[Tifc]		"ifc",
-};
-
 /*
  *  a node in the decision tree
  */
@@ -70,16 +33,12 @@
 	Ipmux	*yes;
 	Ipmux	*no;
 	uchar	type;		/* type of field(Txxxx) */
-	uchar	ctype;		/* tupe of comparison(Cxxxx) */
 	uchar	len;		/* length in bytes of item to compare */
 	uchar	n;		/* number of items val points to */
-	short	off;		/* offset of comparison */
-	short	eoff;		/* end offset of comparison */
-	uchar	skiphdr;	/* should offset start after ipheader */
+	int	off;		/* offset of comparison */
 	uchar	*val;
 	uchar	*mask;
 	uchar	*e;		/* val+n*len*/
-
 	int	ref;		/* so we can garbage collect */
 	Conv	*conv;
 };
@@ -94,6 +53,7 @@
 
 static int	ipmuxsprint(Ipmux*, int, char*, int);
 static void	ipmuxkick(void *x);
+static void	ipmuxfree(Ipmux *f);
 
 static char*
 skipwhite(char *p)
@@ -126,27 +86,33 @@
 	Ipmux *f;
 
 	p = skipwhite(p);
-	if(strncmp(p, "dst", 3) == 0){
+	if(strncmp(p, "ver", 3) == 0){
+		type = Tver;
+		off = 0;
+		len = 1;
+		p += 3;
+	}
+	else if(strncmp(p, "dst", 3) == 0){
 		type = Tdst;
-		off = (int)(uintptr)(ipoff->dst);
-		len = IPv4addrlen;
+		off = offsetof(Ip6hdr, dst[0]);
+		len = IPaddrlen;
 		p += 3;
 	}
 	else if(strncmp(p, "src", 3) == 0){
 		type = Tsrc;
-		off = (int)(uintptr)(ipoff->src);
-		len = IPv4addrlen;
+		off = offsetof(Ip6hdr, src[0]);
+		len = IPaddrlen;
 		p += 3;
 	}
 	else if(strncmp(p, "ifc", 3) == 0){
 		type = Tifc;
-		off = -IPv4addrlen;
-		len = IPv4addrlen;
+		off = -IPaddrlen;
+		len = IPaddrlen;
 		p += 3;
 	}
 	else if(strncmp(p, "proto", 5) == 0){
 		type = Tproto;
-		off = (int)(uintptr)&(ipoff->proto);
+		off = offsetof(Ip6hdr, proto);
 		len = 1;
 		p += 5;
 	}
@@ -164,7 +130,7 @@
 			return nil;
 		p++;
 		off = strtoul(p, &p, 0);
-		if(off < 0 || off > (64-IP4HDR))
+		if(off < 0)
 			return nil;
 		p = skipwhite(p);
 		if(*p != ':')
@@ -193,11 +159,6 @@
 	f->mask = nil;
 	f->n = 1;
 	f->ref = 1;
-	if(type == Tdata)
-		f->skiphdr = 1;
-	else
-		f->skiphdr = 0;
-
 	return f;	
 }
 
@@ -233,7 +194,7 @@
 static Ipmux*
 parsemux(char *p)
 {
-	int n, nomask;
+	int n;
 	Ipmux *f;
 	char *val;
 	char *mask;
@@ -258,7 +219,7 @@
 		case Tdst:
 		case Tifc:
 			f->mask = smalloc(f->len);
-			v4parseip(f->mask, mask);
+			parseipmask(f->mask, mask, 0);
 			break;
 		case Tdata:
 		case Tiph:
@@ -268,15 +229,13 @@
 		default:
 			goto parseerror;
 		}
-		nomask = 0;
-	} else {
-		nomask = 1;
+	} else if(f->type == Tver){
 		f->mask = smalloc(f->len);
-		memset(f->mask, 0xff, f->len);
+		f->mask[0] = 0xF0;
 	}
 
 	/* parse vals */
-	f->n = getfields(val, vals, sizeof(vals)/sizeof(char*), 1, "|");
+	f->n = getfields(val, vals, nelem(vals), 1, "|");
 	if(f->n == 0)
 		goto parseerror;
 	f->val = smalloc(f->n*f->len);
@@ -283,10 +242,21 @@
 	v = f->val;
 	for(n = 0; n < f->n; n++){
 		switch(f->type){
+		case Tver:
+			if(f->n != 1)
+				goto parseerror;
+			if(strcmp(vals[n], "6") == 0)
+				*v = IP_VER6;
+			else if(strcmp(vals[n], "4") == 0)
+				*v = IP_VER4;
+			else
+				goto parseerror;
+			break;
 		case Tsrc:
 		case Tdst:
 		case Tifc:
-			v4parseip(v, vals[n]);
+			if(parseip(v, vals[n]) == -1)
+				goto parseerror;
 			break;
 		case Tproto:
 		case Tdata:
@@ -296,34 +266,11 @@
 		}
 		v += f->len;
 	}
-
-	f->eoff = f->off + f->len;
 	f->e = f->val + f->n*f->len;
-	f->ctype = Cother;
-	if(f->n == 1){
-		switch(f->len){
-		case 1:
-			f->ctype = nomask ? Cbyte : Cmbyte;
-			break;
-		case 2:
-			f->ctype = nomask ? Cshort : Cmshort;
-			break;
-		case 4:
-			if(f->type == Tifc)
-				f->ctype = nomask ? Cifc : Cmifc;
-			else
-				f->ctype = nomask ? Clong : Cmlong;
-			break;
-		}
-	}
 	return f;
 
 parseerror:
-	if(f->mask)
-		free(f->mask);
-	if(f->val)
-		free(f->val);
-	free(f);
+	ipmuxfree(f);
 	return nil;
 }
 
@@ -346,8 +293,7 @@
 		return n;
 
 	/* compare offsets, call earlier ones more specific */
-	n = (a->off+((int)a->skiphdr)*(int)(uintptr)ipoff->data) - 
-		(b->off+((int)b->skiphdr)*(int)(uintptr)ipoff->data);
+	n = a->off - b->off;
 	if(n != 0)
 		return n;
 
@@ -417,6 +363,10 @@
 	*nf = *f;
 	nf->no = ipmuxcopy(f->no);
 	nf->yes = ipmuxcopy(f->yes);
+	if(f->mask != nil){
+		nf->mask = smalloc(f->len);
+		memmove(nf->mask, f->mask, f->len);
+	}
 	nf->val = smalloc(f->n*f->len);
 	nf->e = nf->val + f->len*f->n;
 	memmove(nf->val, f->val, f->n*f->len);
@@ -426,8 +376,10 @@
 static void
 ipmuxfree(Ipmux *f)
 {
-	if(f->val != nil)
-		free(f->val);
+	if(f == nil)
+		return;
+	free(f->val);
+	free(f->mask);
 	free(f);
 }
 
@@ -436,10 +388,8 @@
 {
 	if(f == nil)
 		return;
-	if(f->no != nil)
-		ipmuxfree(f->no);
-	if(f->yes != nil)
-		ipmuxfree(f->yes);
+	ipmuxfree(f->no);
+	ipmuxfree(f->yes);
 	ipmuxfree(f);
 }
 
@@ -514,6 +464,8 @@
 		return ipmuxremove(&ft->no, f);
 	}
 
+	ipmuxremove(&ft->no, f->no);
+
 	/* we found a match */
 	if(--(ft->ref) == 0){
 		/*
@@ -535,8 +487,55 @@
 }
 
 /*
+ * convert to ipv4 filter
+ */
+static Ipmux*
+ipmuxconv4(Ipmux *f)
+{
+	int i, n;
+
+	if(f == nil)
+		return nil;
+
+	switch(f->type){
+	case Tproto:
+		f->off = offsetof(Ip4hdr, proto);
+		break;
+	case Tdst:
+		f->off = offsetof(Ip4hdr, dst[0]);
+		if(0){
+	case Tsrc:
+		f->off = offsetof(Ip4hdr, src[0]);
+		}
+		if(f->len != IPaddrlen)
+			break;
+		n = 0;
+		for(i = 0; i < f->n; i++){
+			if(isv4(f->val + i*IPaddrlen)){
+				memmove(f->val + n*IPv4addrlen, f->val + i*IPaddrlen + IPv4off, IPv4addrlen);
+				n++;
+			}
+		}
+		if(n == 0){
+			ipmuxtreefree(f);
+			return nil;
+		}
+		f->n = n;
+		f->len = IPv4addrlen;
+		if(f->mask != nil)
+			memmove(f->mask, f->mask+IPv4off, IPv4addrlen);
+	}
+	f->e = f->val + f->n*f->len;
+
+	f->yes = ipmuxconv4(f->yes);
+	f->no = ipmuxconv4(f->no);
+
+	return f;
+}
+
+/*
  *  connection request is a semi separated list of filters
- *  e.g. proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
+ *  e.g. ver=4;proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
  *
  *  there's no protection against overlapping specs.
  */
@@ -572,6 +571,18 @@
 		return Ebadarg;
 	mux->conv = c;
 
+	if(chain->type != Tver) {
+		char ver6[] = "ver=6";
+		mux = parsemux(ver6);
+		mux->yes = chain;
+		mux->no = ipmuxcopy(chain);
+		chain = mux;
+	}
+	if(*chain->val == IP_VER4)
+		chain->yes = ipmuxconv4(chain->yes);
+	else
+		chain->no = ipmuxconv4(chain->no);
+
 	/* save a copy of the chain so we can later remove it */
 	mux = ipmuxcopy(chain);
 	r = (Ipmuxrock*)(c->ptcl);
@@ -647,7 +658,7 @@
 
 	bp = qget(c->wq);
 	if(bp != nil) {
-		Myip4hdr *ih4 = (Myip4hdr*)(bp->rp);
+		Ip4hdr *ih4 = (Ip4hdr*)(bp->rp);
 
 		if((ih4->vihl & 0xF0) != IP_VER6)
 			ipoput4(c->p->f, bp, 0, ih4->ttl, ih4->tos, nil);
@@ -656,82 +667,74 @@
 	}
 }
 
+static int
+maskmemcmp(uchar *m, uchar *v, uchar *c, int n)
+{
+	int i;
+
+	if(m == nil)
+		return memcmp(v, c, n) != 0;
+
+	for(i = 0; i < n; i++)
+		if((v[i] & m[i]) != c[i])
+			return 1;
+	return 0;
+}
+
 static void
 ipmuxiput(Proto *p, Ipifc *ifc, Block *bp)
 {
-	int len, hl;
 	Fs *f = p->f;
-	uchar *m, *h, *v, *e, *ve, *hp;
 	Conv *c;
+	Iplifc *lifc;
 	Ipmux *mux;
-	Myip4hdr *ip;
+	uchar *v;
+	Ip4hdr *ip4;
 	Ip6hdr *ip6;
+	int off, hl;
 
-	ip = (Myip4hdr*)bp->rp;
-	hl = (ip->vihl&0x0F)<<2;
+	ip4 = (Ip4hdr*)bp->rp;
+	if((ip4->vihl & 0xF0) == IP_VER4) {
+		hl = (ip4->vihl&0x0F)<<2;
+		ip6 = nil;
+	} else {
+		hl = IP6HDR;
+		ip6 = (Ip6hdr*)ip4;
+	}
 
 	if(p->priv == nil)
 		goto nomatch;
 
-	h = bp->rp;
-	len = BLEN(bp);
+	c = nil;
+	lifc = nil;
 
-	/* run the v4 filter */
+	/* run the filter */
 	rlock(f);
-	c = nil;
 	mux = f->ipmux->priv;
 	while(mux != nil){
-		if(mux->eoff > len){
-			mux = mux->no;
-			continue;
-		}
-		hp = h + mux->off + ((int)mux->skiphdr)*hl;
-		switch(mux->ctype){
-		case Cbyte:
-			if(*mux->val == *hp)
-				goto yes;
+		switch(mux->type){
+		case Tifc:
+			if(mux->len != IPaddrlen)
+				goto no;
+			for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next)
+				for(v = mux->val; v < mux->e; v += IPaddrlen)
+					if(maskmemcmp(mux->mask, lifc->local, v, IPaddrlen) == 0)
+						goto yes;
+			goto no;
+		case Tdata:
+			off = hl;
 			break;
-		case Cmbyte:
-			if((*hp & *mux->mask) == *mux->val)
-				goto yes;
-			break;
-		case Cshort:
-			if(*((ushort*)mux->val) == *(ushort*)hp)
-				goto yes;
-			break;
-		case Cmshort:
-			if((*(ushort*)hp & (*((ushort*)mux->mask))) == *((ushort*)mux->val))
-				goto yes;
-			break;
-		case Clong:
-			if(*((ulong*)mux->val) == *(ulong*)hp)
-				goto yes;
-			break;
-		case Cmlong:
-			if((*(ulong*)hp & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
-				goto yes;
-			break;
-		case Cifc:
-			if(*((ulong*)mux->val) == *(ulong*)(ifc->lifc->local + IPv4off))
-				goto yes;
-			break;
-		case Cmifc:
-			if((*(ulong*)(ifc->lifc->local + IPv4off) & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
-				goto yes;
-			break;
 		default:
-			v = mux->val;
-			for(e = mux->e; v < e; v = ve){
-				m = mux->mask;
-				hp = h + mux->off;
-				for(ve = v + mux->len; v < ve; v++){
-					if((*hp++ & *m++) != *v)
-						break;
-				}
-				if(v == ve)
-					goto yes;
-			}
+			off = 0;
+			break;
 		}
+		off += mux->off;
+		if(off < 0 || off + mux->len > BLEN(bp))
+			goto no;
+		for(v = mux->val; v < mux->e; v += mux->len)
+			if(maskmemcmp(mux->mask, bp->rp + off, v, mux->len) == 0)
+				goto yes;
+no:
 		mux = mux->no;
 		continue;
 yes:
@@ -744,7 +747,9 @@
 	if(c != nil){
 		/* tack on interface address */
 		bp = padblock(bp, IPaddrlen);
-		ipmove(bp->rp, ifc->lifc->local);
+		if(lifc == nil)
+			lifc = ifc->lifc;
+		ipmove(bp->rp, lifc != nil ? lifc->local : IPnoaddr);
 		qpass(c->rq, concatblock(bp));
 		return;
 	}
@@ -751,18 +756,15 @@
 
 nomatch:
 	/* doesn't match any filter, hand it to the specific protocol handler */
-	ip = (Myip4hdr*)bp->rp;
-	if((ip->vihl & 0xF0) == IP_VER4) {
-		p = f->t2p[ip->proto];
-	} else {
-		ip6 = (Ip6hdr*)bp->rp;
+	if(ip6 != nil)
 		p = f->t2p[ip6->proto];
-	}
-	if(p != nil && p->rcv != nil)
-		(*p->rcv)(p, ifc, bp);
 	else
-		freeblist(bp);
-	return;
+		p = f->t2p[ip4->proto];
+	if(p != nil && p->rcv != nil){
+		(*p->rcv)(p, ifc, bp);
+		return;
+	}
+	freeblist(bp);
 }
 
 static int
@@ -778,11 +780,14 @@
 		n += snprint(buf+n, len-n, "\n");
 		return n;
 	}
-	n += snprint(buf+n, len-n, "h[%d:%d]&", 
-               mux->off+((int)mux->skiphdr)*((int)(uintptr)ipoff->data), 
-               mux->off+(((int)mux->skiphdr)*((int)(uintptr)ipoff->data))+mux->len-1);
-	for(i = 0; i < mux->len; i++)
-		n += snprint(buf+n, len - n, "%2.2ux", mux->mask[i]);
+	n += snprint(buf+n, len-n, "%s[%d:%d]", 
+		mux->type == Tdata ? "data": "iph",
+		mux->off, mux->off+mux->len-1);
+	if(mux->mask != nil){
+		n += snprint(buf+n, len-n, "&");
+		for(i = 0; i < mux->len; i++)
+			n += snprint(buf+n, len - n, "%2.2ux", mux->mask[i]);
+	}
 	n += snprint(buf+n, len-n, "=");
 	v = mux->val;
 	for(j = 0; j < mux->n; j++){