shithub: musw

Download patch

ref: 62e75d8830eb56ab03bd4689d51ffd6d4150f461
parent: 775309861b51dd1f340d82074b7e9234f3e0675e
author: rodri <[email protected]>
date: Sat Feb 11 18:21:00 EST 2023

implemented per-packet HMAC-MD5 signatures so that frames tampered with by a MITM are detected and discarded.

--- a/dat.h
+++ b/dat.h
@@ -46,7 +46,7 @@
 
 enum {
 	ProtocolID	= 0x5753554d,	/* MUSW */
-	Framehdrsize	= 4+1+4+4+2,
+	Framehdrsize	= 4+1+4+4+2+MD5dlen,
 	MTU		= 1024
 };
 
@@ -147,6 +147,7 @@
 	u32int seq;
 	u32int ack;
 	u16int len;
+	uchar sig[MD5dlen];
 	uchar data[];
 };
 
@@ -177,3 +178,5 @@
 	Universe *u;
 	Party *prev, *next;
 };
+
+#pragma varargck type "Φ" Frame*
--- /dev/null
+++ b/fmt.c
@@ -1,0 +1,25 @@
+#include <u.h>
+#include <libc.h>
+#include <ip.h>
+#include <mp.h>
+#include <libsec.h>
+#include <thread.h>
+#include <draw.h>
+#include <geometry.h>
+#include "dat.h"
+#include "fns.h"
+
+/* Φfmt prints a Frame*: header fields, then sig in hex; sums fmtprint results. */
+int
+Φfmt(Fmt *f)
+{
+	Frame *fp;
+	int r, j;
+
+	fp = va_arg(f->args, Frame*);
+	r = fmtprint(f, "id %x type %ud seq %ud ack %ud len %ud sig ",
+		fp->id, fp->type, fp->seq, fp->ack, fp->len);
+	for(j = 0; j < MD5dlen; j++)
+		r += fmtprint(f, "%2.2x", fp->sig[j]);
+	return r;
+}
--- a/fns.h
+++ b/fns.h
@@ -53,4 +53,11 @@
 NetConn *newnetconn(NCState, Udphdr*);
 void delnetconn(NetConn*);
 Frame *newframe(Frame*, u8int, u32int, u32int, u16int, uchar*);
+void signframe(Frame*, ulong);
+int verifyframe(Frame*, ulong);
 void delframe(Frame*);
+
+/*
+ * fmt
+ */
+int Φfmt(Fmt*);
--- a/mkfile
+++ b/mkfile
@@ -15,6 +15,7 @@
 	universe.$O\
 	sprite.$O\
 	net.$O\
+	fmt.$O\
 
 HFILES=\
 	dat.h\
--- a/musw.c
+++ b/musw.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <bio.h>
 #include <thread.h>
 #include <draw.h>
@@ -164,6 +166,7 @@
 
 	frame = newframe(nil, NCinput, 0, 0, sizeof(kdown), nil);
 	pack(frame->data, frame->len, "k", kdown);
+	signframe(frame, netconn.dh.priv);
 	sendp(egress, frame);
 }
 
@@ -241,9 +244,8 @@
 		if(debug){
 			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
 			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
-			fprint(2, "%I!%ud ← %I!%ud | rcvd type %ud seq %ud ack %ud len %ud\n",
-				frame->udp.laddr, lport, frame->udp.raddr, rport,
-				frame->type, frame->seq, frame->ack, frame->len);
+			fprint(2, "%I!%ud → %I!%ud | rcvd %Φ\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport, frame);
 		}
 	}
 	closeioproc(io);
@@ -267,7 +269,7 @@
 				unpack(frame->data, frame->len, "kk", &netconn.dh.p, &netconn.dh.g);
 
 				newf = newframe(frame, NCdhx, 0, 0, sizeof(ulong), nil);
-	
+
 				netconn.dh.sec = truerand();
 				pack(newf->data, newf->len, "k", dhgenkey(netconn.dh.g, netconn.dh.sec, netconn.dh.p));
 				sendp(egress, newf);
@@ -274,7 +276,7 @@
 
 				if(debug)
 					fprint(2, "\tsent pubkey %ld\n", dhgenkey(netconn.dh.g, netconn.dh.sec, netconn.dh.p));
-	
+
 				break;
 			case NSdhx:
 				unpack(frame->data, frame->len, "k", &netconn.dh.pub);
@@ -281,7 +283,7 @@
 				netconn.state = NCSConnected;
 
 				if(debug)
-					fprint(2, "\trecvd pubkey %ld\n", netconn.dh.pub);
+					fprint(2, "\trcvd pubkey %ld\n", netconn.dh.pub);
 
 				netconn.dh.priv = dhgenkey(netconn.dh.pub, netconn.dh.sec, netconn.dh.p);
 				break;
@@ -288,6 +290,12 @@
 			}
 			break;
 		case NCSConnected:
+			if(verifyframe(frame, netconn.dh.priv) != 0){
+				if(debug)
+					fprint(2, "\tbad signature\n");
+				goto discard;
+			}
+
 			switch(frame->type){
 			case NSsimstate:
 				unpack(frame->data, frame->len, "PdPdP",
@@ -297,6 +305,7 @@
 				break;
 			case NSnudge:
 				newf = newframe(frame, NCnudge, 0, 0, 0, nil);
+				signframe(newf, netconn.dh.priv);
 
 				sendp(egress, newf);
 
@@ -308,7 +317,7 @@
 			break;
 		}
 discard:
-		free(frame);
+		delframe(frame);
 	}
 }
 
@@ -332,9 +341,8 @@
 		if(debug){
 			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
 			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
-			fprint(2, "%I!%ud → %I!%ud | sent type %ud seq %ud ack %ud len %ud\n",
-				frame->udp.laddr, lport, frame->udp.raddr, rport,
-				frame->type, frame->seq, frame->ack, frame->len);
+			fprint(2, "%I!%ud → %I!%ud | sent %Φ\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport, frame);
 		}
 
 		free(frame);
@@ -438,6 +446,7 @@
 
 	GEOMfmtinstall();
 	fmtinstall('I', eipfmt);
+	fmtinstall(L'Φ', Φfmt);
 	ARGBEGIN{
 	case 'd':
 		debug++;
--- a/muswd.c
+++ b/muswd.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <thread.h>
 #include <draw.h>
 #include <geometry.h>
@@ -46,7 +48,7 @@
 
 	ncpe = conns+nconns;
 
-	for(ncp = conns; ncp < conns+nconns; ncp++)
+	for(ncp = conns; ncp < ncpe; ncp++)
 		if(*ncp == nc){
 			memmove(ncp, ncp+1, sizeof(NetConn*)*(ncpe-ncp-1));
 			nconns--;
@@ -77,9 +79,8 @@
 		if(debug){
 			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
 			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
-			fprint(2, "%I!%ud ← %I!%ud | rcvd type %ud seq %ud ack %ud len %ud\n",
-				frame->udp.laddr, lport, frame->udp.raddr, rport,
-				frame->type, frame->seq, frame->ack, frame->len);
+			fprint(2, "%I!%ud → %I!%ud | rcvd %Φ\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport, frame);
 		}
 	}
 	closeioproc(io);
@@ -103,13 +104,13 @@
 			if(frame->type == NChi){
 				nc = newnetconn(NCSConnecting, &frame->udp);
 				putconn(nc);
-	
+
 				newf = newframe(frame, NShi, 0, 0, 2*sizeof(ulong), nil);
-	
+
 				dhgenpg(&nc->dh.p, &nc->dh.g);
 				pack(newf->data, newf->len, "kk", nc->dh.p, nc->dh.g);
 				sendp(egress, newf);
-	
+
 				if(debug)
 					fprint(2, "\tsent p %ld g %ld\n", nc->dh.p, nc->dh.g);
 			}else
@@ -124,10 +125,10 @@
 				nc->state = NCSConnected;
 
 				if(debug)
-					fprint(2, "\trecvd pubkey %ld\n", nc->dh.pub);
+					fprint(2, "\trcvd pubkey %ld\n", nc->dh.pub);
 
 				newf = newframe(frame, NSdhx, 0, 0, sizeof(ulong), nil);
-	
+
 				nc->dh.sec = truerand();
 				nc->dh.priv = dhgenkey(nc->dh.pub, nc->dh.sec, nc->dh.p);
 				pack(newf->data, newf->len, "k", dhgenkey(nc->dh.g, nc->dh.sec, nc->dh.p));
@@ -140,6 +141,12 @@
 			}
 			break;
 		case NCSConnected:
+			if(verifyframe(frame, nc->dh.priv) != 0){
+				if(debug)
+					fprint(2, "\tbad signature\n");
+				goto discard;
+			}
+
 			switch(frame->type){
 			case NCinput:
 				unpack(frame->data, frame->len, "k", &kdown);
@@ -150,13 +157,13 @@
 				break;
 			case NCbuhbye:
 				popconn(nc);
-				free(nc);
+				delnetconn(nc);
 				break;
 			}
 			break;
 		}
 discard:
-		free(frame);
+		delframe(frame);
 	}
 }
 
@@ -180,9 +187,8 @@
 		if(debug){
 			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
 			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
-			fprint(2, "%I!%ud → %I!%ud | sent type %ud seq %ud ack %ud len %ud\n",
-				frame->udp.laddr, lport, frame->udp.raddr, rport,
-				frame->type, frame->seq, frame->ack, frame->len);
+			fprint(2, "%I!%ud → %I!%ud | sent %Φ\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport, frame);
 		}
 
 		free(frame);
@@ -334,6 +340,7 @@
 
 	GEOMfmtinstall();
 	fmtinstall('I', eipfmt);
+	fmtinstall(L'Φ', Φfmt);
 	addr = "udp!*!112";
 	ARGBEGIN{
 	case 'a':
--- a/net.c
+++ b/net.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <thread.h>
 #include <draw.h>
 #include <geometry.h>
@@ -69,6 +71,7 @@
 	Frame *f;
 
 	f = emalloc(sizeof(Frame)+len);
+	memset(f, 0, sizeof(Frame));
 	f->id = ProtocolID;
 	f->type = type;
 	if(pf != nil){
@@ -85,6 +88,40 @@
 		memmove(f->data, data, f->len);
 
 	return f;
+}
+
+/* HMAC-MD5 of f packed with a zeroed sig, keyed by key's low 32 bits. */
+static void
+framemac(Frame *f, ulong key, uchar *h)
+{
+	uchar k[4], sig[MD5dlen], msg[MTU];
+	int n;
+
+	k[0] = key; k[1] = key>>8; k[2] = key>>16; k[3] = key>>24;
+	memmove(sig, f->sig, MD5dlen);
+	memset(f->sig, 0, MD5dlen);
+	n = pack(msg, sizeof msg, "f", f);
+	memmove(f->sig, sig, MD5dlen);	/* restore first: h may be f->sig */
+	hmac_md5(msg, n, k, sizeof k, h, nil);
+}
+
+/* signframe stores in f->sig the HMAC of f under key. */
+void
+signframe(Frame *f, ulong key)
+{
+	framemac(f, key, f->sig);
+}
+
+/* verifyframe returns 0 iff f->sig is valid; compares in constant time. */
+int
+verifyframe(Frame *f, ulong key)
+{
+	uchar h[MD5dlen];
+	int i, d;
+	framemac(f, key, h);
+	for(i = d = 0; i < MD5dlen; i++)
+		d |= h[i] ^ f->sig[i];
+	return d != 0;
 }
 
 void
--- a/pack.c
+++ b/pack.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <draw.h>
 #include <geometry.h>
 #include "dat.h"
@@ -91,6 +93,7 @@
 			put4(p, F->seq), p += 4;
 			put4(p, F->ack), p += 4;
 			put2(p, F->len), p += 2;
+			memmove(p, F->sig, MD5dlen), p += MD5dlen;
 
 			if(p+F->len > e)
 				goto err;
@@ -161,6 +164,7 @@
 			F->seq = get4(p), p += 4;
 			F->ack = get4(p), p += 4;
 			F->len = get2(p), p += 2;
+			memmove(F->sig, p, MD5dlen), p += MD5dlen;
 
 			if(p+F->len > e)
 				goto err;
--- a/party.c
+++ b/party.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <draw.h>
 #include <geometry.h>
 #include "dat.h"
--- a/physics.c
+++ b/physics.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <draw.h>
 #include <geometry.h>
 #include "dat.h"
--- a/sprite.c
+++ b/sprite.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <draw.h>
 #include <geometry.h>
 #include "dat.h"
--- a/universe.c
+++ b/universe.c
@@ -1,6 +1,8 @@
 #include <u.h>
 #include <libc.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include <draw.h>
 #include <geometry.h>
 #include "dat.h"