shithub: riscv

Download patch

ref: a944c37d68b3742ae5156cec3bd23f64ad5c7a1b
parent: 99825e22ed403c8636751238743d01a1b143f8e2
author: cinap_lenrek <[email protected]>
date: Fri Apr 21 15:23:56 EDT 2017

ssh: actually handle flow control and channel ids

--- a/sys/src/cmd/ssh.c
+++ b/sys/src/cmd/ssh.c
@@ -45,19 +45,33 @@
 	MSG_CHANNEL_FAILURE,
 };
 
+
+enum {
+	Overhead = 256,		// room beyond MaxPacket in Oneway.b, enough for the MSG_CHANNEL_DATA header
+	MaxPacket = 1<<15,	// max channel packet size: advertised to the peer and upper bound for send.pkt
+	WinPackets = 8,		// window = WinPackets*pkt; (1<<15) * 8 = 256K
+};
+
 typedef struct
 {
-	int		pid;
 	u32int		seq;
 	u32int		kex;
+	u32int		chan;		// recv: our channel id (checked on incoming msgs); send: the peer's id
+
+	int		win;		// bytes remaining in this direction's flow-control window
+	int		pkt;		// maximum channel packet size for this direction
+	int		eof;		// end-of-stream flag; send.eof++ ensures a single EOF/CLOSE is sent
+
 	Chachastate	cs1;
 	Chachastate	cs2;
-	char		*v;
-	char		eof;
 
 	uchar		*r;
 	uchar		*w;
-	uchar		b[1<<15];
+	uchar		b[Overhead + MaxPacket];	// one packet: up to MaxPacket of data plus header room
+
+	char		*v;		// SSH version string (recv.v holds the server's)
+	int		pid;
+	Rendez;			// anonymous Rendez: main rsleep(&send)s for window space, WINDOW_ADJUST rwakeup(&send)s; l = &sl
 } Oneway;
 
 int nsid;
@@ -902,7 +916,6 @@
 	switch(recv.r[0]){
 	case MSG_IGNORE:
 	case MSG_GLOBAL_REQUEST:
-	case MSG_CHANNEL_WINDOW_ADJUST:
 		return;
 	case MSG_DISCONNECT:
 		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
@@ -922,24 +935,38 @@
 	case MSG_CHANNEL_DATA:
 		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
 			break;
-		if(c != 0)
+		if(c != recv.chan)
 			break;
 		if(write(1, s, n) != n)
 			sysfatal("write out: %r");
 	Winadjust:
-		sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
+		recv.win -= n;
+		if(recv.win < recv.pkt){
+			n = WinPackets*recv.pkt;
+			recv.win += n;
+			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
+		}
 		return;
 	case MSG_CHANNEL_EXTENDED_DATA:
 		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
 			break;
-		if(c != 0)
+		if(c != recv.chan)
 			break;
 		if(b == 1) write(2, s, n);
 		goto Winadjust;
+	case MSG_CHANNEL_WINDOW_ADJUST:
+		if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
+			break;
+		if(c != recv.chan)
+			break;
+		send.win += n;
+		if(send.win >= send.pkt)
+			rwakeup(&send);
+		return;
 	case MSG_CHANNEL_REQUEST:
 		if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
 			break;
-		if(c != 0)
+		if(c != recv.chan)
 			break;
 		if(n == 11 && memcmp(s, "exit-signal", n) == 0){
 			if(unpack(p, recv.w-p, "s", &s, &n) < 0)
@@ -1044,7 +1071,6 @@
 void
 main(int argc, char *argv[])
 {
-	static char buf[8*1024];
 	static QLock sl;
 	int b, n, c;
 	char *s;
@@ -1106,6 +1132,8 @@
 		sysfatal("bad server version: %s", recv.v);
 	recv.v = strdup(recv.v);
 
+	send.l = recv.l = &sl;
+
 	kex(0);
 
 	if(user == nil)
@@ -1124,12 +1152,16 @@
 	if(pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
 		sysfatal("auth: %r");
 
+	recv.pkt = MaxPacket;
+	recv.win = WinPackets*recv.pkt;
+	recv.chan = 0;
+
 	/* open hailing frequencies */
 	sendpkt("bsuuu", MSG_CHANNEL_OPEN,
 		"session", 7,
-		0,
-		8*sizeof(buf),
-		sizeof(buf));
+		recv.chan,
+		recv.win,
+		recv.pkt);
 
 Next1:	switch(recvpkt()){
 	default:
@@ -1143,6 +1175,11 @@
 		break;
 	}
 
+	if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
+		sysfatal("bad channel open confirmation");
+	if(send.pkt <= 0 || send.pkt > MaxPacket)
+		send.pkt = MaxPacket;
+
 	notify(catch);
 	atexit(shutdown);
 
@@ -1170,7 +1207,7 @@
 	if(raw) {
 		rawon();
 		sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
-			0,
+			send.chan,
 			"pty-req", 7,
 			0,
 			tty.term, strlen(tty.term),
@@ -1182,19 +1219,20 @@
 	}
 	if(cmd == nil){
 		sendpkt("busb", MSG_CHANNEL_REQUEST,
-			0,
+			send.chan,
 			"shell", 5,
 			0);
 	} else {
 		sendpkt("busbs", MSG_CHANNEL_REQUEST,
-			0,
+			send.chan,
 			"exec", 4,
 			0,
 			cmd, strlen(cmd));
 	}
 	for(;;){
+		static uchar buf[MaxPacket];
 		qunlock(&sl);
-		n = read(0, buf, sizeof(buf));
+		n = read(0, buf, send.pkt);
 		qlock(&sl);
 		if(send.eof)
 			break;
@@ -1201,7 +1239,7 @@
 		if(n < 0 && wasintr()){
 			if(!raw) break;
 			sendpkt("busbs", MSG_CHANNEL_REQUEST,
-				0,
+				send.chan,
 				"signal", 6,
 				0,
 				"INT", 3);
@@ -1210,12 +1248,15 @@
 		}
 		if(n <= 0)
 			break;
+		send.win -= n;
+		while(send.win < 0)
+			rsleep(&send);
 		sendpkt("bus", MSG_CHANNEL_DATA,
-			0,
+			send.chan,
 			buf, n);
 	}
 	if(send.eof++ == 0)
-		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0);
+		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
 	qunlock(&sl);
 
 	exits(nil);