shithub: riscv

ref: 2b323158bd8f2a1e1f86796e0bb0139e0c77dace
dir: /sys/src/cmd/import.c/

View raw version
#include <u.h>
#include <libc.h>
#include <auth.h>
#include <libsec.h>

/*
 * Encryption protocols selectable with -E.  The values index
 * encprotos[], whose entry is sent to the peer in the "impo" line.
 */
enum {
	Encnone,
	Encssl,
	Enctls,
};

/* names accepted by -E (matched with lookup), indexed by the enum above; nil-terminated */
static char *encprotos[] = {
	[Encnone] =	"clear",
	[Encssl] =	"ssl",
	[Enctls] = 	"tls",
			nil,
};

/* extra factotum key attributes passed to auth_proxy in connect (-k) */
char		*keyspec = "";
/* filter command interposed on the connection; set to aan by -p, nil for none */
char		*filterp;
/* algorithms handed to pushssl (-e); nil (from -e clear or an empty -e) skips ssl */
char		*ealgs = "rc4_256 sha1";
/* index into encprotos, chosen with -E */
int		encproto = Encnone;
/* path of the filter installed by -p */
char		*aan = "/bin/aan";
/* address filter announces on when this end is the -B side (-n) */
char		*anstring  = "tcp!*!0";
/* result of auth_proxy; its secret seeds the ssl keys in main */
AuthInfo 	*ai;
/* -d: when set, filter prints the address it hands to the filter command */
int		debug;
/* cleared by -A: connect skips authentication (passive always authenticates) */
int		doauth = 1;
/* set by catcher whenever a note arrives; connect uses it to report a timeout */
int		timedout;
/* -z: connect neither sends the tree name nor waits for "OK" */
int		skiptree;

/* forward declarations */
int	connect(char*, char*);
int	passive(void);
void	catcher(void*, char*);
void	sysfatal(char*, ...);
void	usage(void);
int	filter(int, char *, char *);

static void	mksecret(char *, uchar *);

/*
 * Format a string and write it, NUL included, to this process's
 * #p/<pid>/args file so ps shows it as the argument list.
 * Modelled on libthread's threadsetname without pulling in libthread.
 * Failures (allocation or open) are silently ignored.
 */
void
procsetname(char *fmt, ...)
{
	char path[128], *args;
	va_list ap;
	int f;

	va_start(ap, fmt);
	args = vsmprint(fmt, ap);
	va_end(ap);
	if(args == nil)
		return;

	snprint(path, sizeof path, "#p/%d/args", getpid());
	f = open(path, OWRITE);
	if(f >= 0){
		write(f, args, strlen(args)+1);
		close(f);
	}
	free(args);
}

/*
 * Create the srv file name (mode 0600) holding the decimal fd number
 * srvfd, then set environment variable envname to name.
 * Returns quietly if the file cannot be created; a short write is fatal.
 */
void
post(char *name, char *envname, int srvfd)
{
	char num[32];
	int f;
	long n;

	if((f = create(name, OWRITE, 0600)) < 0)
		return;

	snprint(num, sizeof(num), "%d", srvfd);
	n = strlen(num);
	if(write(f, num, n) != n)
		sysfatal("srv write: %r");
	close(f);

	putenv(envname, name);
}

/*
 * Find s in the null-terminated string table l.
 * Returns the index of the first exact match, or -1 if there is none.
 */
static int
lookup(char *s, char *l[])
{
	char **p;

	for(p = l; *p != 0; p++){
		if(strcmp(*p, s) == 0)
			return (int)(p - l);
	}
	return -1;
}

/*
 * import [options] host remotefs [mountpoint]
 * import -B [options] mountpoint [command [arg ...]]
 *
 * Connect to an exportfs peer, or with -B authenticate a caller on fd 0.
 * Announce the filter and encryption protocol in an "impo" line.
 * Optionally exchange random key parts and push ssl, and optionally
 * interpose aan (-p).  Post the connection in /srv (-s), then mount it
 * on the mount point.  With -B and extra arguments, the command is
 * exec'd once the mount is in place.
 */
void
main(int argc, char **argv)
{
	char *mntpt, *srvpost, srvfile[64];
	int backwards = 0, fd, mntflags;

	quotefmtinstall();
	srvpost = nil;
	mntflags = MREPL;
	ARGBEGIN{
	case 'A':
		doauth = 0;
		break;
	case 'a':
		mntflags = MAFTER;
		break;
	case 'b':
		mntflags = MBEFORE;
		break;
	case 'c':
		mntflags |= MCREATE;
		break;
	case 'C':
		mntflags |= MCACHE;
		break;
	case 'd':
		debug++;
		break;
	case 'f':
		/* ignored but allowed for compatibility */
		break;
	case 'E':
		if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
			usage();
		break;
	case 'e':
		ealgs = EARGF(usage());
		if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
			ealgs = nil;
		break;
	case 'k':
		keyspec = EARGF(usage());
		break;
	case 'p':
		filterp = aan;
		break;
	case 'n':
		anstring = EARGF(usage());
		break;
	case 's':
		srvpost = EARGF(usage());
		break;
	case 'B':
		backwards = 1;
		break;
	case 'z':
		skiptree = 1;
		break;
	default:
		usage();
	}ARGEND;

	mntpt = 0;		/* to shut up compiler */
	if(backwards){
		/* argv: mountpoint [command [arg ...]] */
		switch(argc) {
		default:
			mntpt = argv[0];
			break;
		case 0:
			usage();
		}
	} else {
		/* argv: host remotefs [mountpoint]; remotefs is also the default mount point */
		switch(argc) {
		case 2:
			mntpt = argv[1];
			break;
		case 3:
			mntpt = argv[2];
			break;
		default:
			usage();
		}
	}

	/* sysfatal already prefixes the program name; argv[0] here is an operand */
	if (encproto == Enctls)
		sysfatal("tls has not yet been implemented");

	/* catcher sets timedout; the alarm bounds the whole connection setup */
	notify(catcher);
	alarm(60*1000);

	if (backwards)
		fd = passive();
	else
		fd = connect(argv[0], argv[1]);

	if(fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter", encprotos[encproto]) < 0)
		sysfatal("can't write impo request: %r");

	/*
	 * NOTE(review): with -E ssl but ealgs == nil (-e clear) or ai == nil (-A),
	 * "ssl" has already been announced above, yet no key exchange happens
	 * here; confirm the peer tolerates that.
	 */
	if (encproto != Encnone && ealgs && ai) {
		uchar key[16], digest[SHA1dlen];
		char fromclientsecret[21];
		char fromserversecret[21];
		int i;

		/*
		 * key layout: [0,4) our random bytes, [4,12) the first 8 bytes
		 * of the shared auth secret, [12,16) the peer's random bytes.
		 */
		if(ai->nsecret < 8)
			sysfatal("secret too small to ssl");
		memmove(key+4, ai->secret, 8);

		/* exchange random numbers */
		srand(truerand());
		for(i = 0; i < 4; i++)
			key[i] = rand();
		if(write(fd, key, 4) != 4)
			sysfatal("can't write key part: %r");
		if(readn(fd, key+12, 4) != 4)
			sysfatal("can't read key part: %r");

		/* scramble into two secrets: 10 digest bytes each, hex-encoded */
		sha1(key, sizeof(key), digest, nil);
		mksecret(fromclientsecret, digest);
		mksecret(fromserversecret, digest+10);

		/* the filter sits underneath ssl and is added after the raw key exchange */
		if (filterp)
			fd = filter(fd, filterp, backwards ? nil : argv[0]);

		/* set up encryption */
		procsetname("pushssl");
		fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
		if(fd < 0)
			sysfatal("can't establish ssl connection: %r");
	}
	else if (filterp)
		fd = filter(fd, filterp, backwards ? nil : argv[0]);

	if(ai){
		auth_freeAI(ai);
		ai = nil;
	}

	if(srvpost){
		snprint(srvfile, sizeof(srvfile), "/srv/%s", srvpost);
		remove(srvfile);
		post(srvfile, srvpost, fd);
	}
	procsetname("mount on %s", mntpt);
	/* name the mount point: argv[1] may be remotefs, a -B command, or nil */
	if(mount(fd, -1, mntpt, mntflags, "") < 0)
		sysfatal("can't mount %s: %r", mntpt);
	alarm(0);

	/* -B with extra arguments: run argv[1..] with the mount in place */
	if(backwards && argc > 1){
		exec(argv[1], &argv[1]);
		sysfatal("exec: %r");
	}
	exits(0);
}

/*
 * Note handler: record that a note arrived (connect reports this as a
 * timeout), resume after "alarm", and take the default action otherwise.
 */
void
catcher(void*, char *msg)
{
	timedout = 1;
	noted(strcmp(msg, "alarm") == 0 ? NCONT : NDFLT);
}

/*
 * Dial system's exportfs service.  Unless -A, authenticate as a p9any
 * client using keyspec.  Unless -z, send the remote tree name and
 * require the two-byte reply "OK".
 * Returns the connected fd; every failure exits via sysfatal.
 */
int
connect(char *system, char *tree)
{
	char buf[ERRMAX], dir[128], *na;
	int fd, n;
	long len;

	na = netmkaddr(system, 0, "exportfs");
	procsetname("dial %s", na);
	if((fd = dial(na, 0, dir, 0)) < 0)
		sysfatal("can't dial %s: %r", system);

	if(doauth){
		procsetname("auth_proxy auth_getkey proto=p9any role=client %s", keyspec);
		ai = auth_proxy(fd, auth_getkey, "proto=p9any role=client %s", keyspec);
		if(ai == nil)
			sysfatal("%r: %s", system);
	}

	if(!skiptree){
		procsetname("writing tree name %s", tree);
		len = strlen(tree);
		if(write(fd, tree, len) != len)
			sysfatal("can't write tree: %r");

		/* default message if the read fails or hits EOF */
		strcpy(buf, "can't read tree");

		procsetname("awaiting OK for %s", tree);
		n = read(fd, buf, sizeof buf - 1);
		if(n!=2 || buf[0]!='O' || buf[1]!='K'){
			if (timedout)
				sysfatal("timed out connecting to %s", na);
			/*
			 * Show exactly what the peer sent.  Without this, leftover
			 * default text or uninitialised bytes after the first n
			 * would be printed too.
			 */
			if(n > 0)
				buf[n] = '\0';
			sysfatal("bad remote tree: %s", buf);
		}
	}
	return fd;
}

/*
 * The -B side: authenticate the caller on fd 0 as a p9any server and
 * become the authenticated user (auth_chuid).  Set $service to
 * "import" and return a dup of fd 0.  Fds 0 and 1 are then reopened on
 * /dev/null.  Each open takes the lowest free fd, which is the one just
 * closed, so the order of these calls matters.
 * Any authentication failure exits via sysfatal.
 */
int
passive(void)
{
	int fd;

	/*
	 * Ignore doauth==0 on purpose.  Is it useful here?
	 */

	procsetname("auth_proxy auth_getkey proto=p9any role=server");
	ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
	if(ai == nil)
		sysfatal("auth_proxy: %r");
	if(auth_chuid(ai, nil) < 0)
		sysfatal("auth_chuid: %r");
	putenv("service", "import");

	/* keep the connection on a fresh fd, then park 0 and 1 on /dev/null */
	fd = dup(0, -1);
	close(0);
	open("/dev/null", ORDWR);
	close(1);
	open("/dev/null", ORDWR);

	return fd;
}

/*
 * Print the synopsis on standard error and exit with status "usage".
 * Both invocation forms are shown: dialing a host, and -B, where the
 * connection arrives on fd 0.
 */
void
usage(void)
{
	fprint(2, "usage: import [-abcC] [-A] [-E clear|ssl|tls] "
"[-e 'crypt auth'|clear] [-k keypattern] [-p] [-n address] [-s srvname] [-z] host remotefs [mountpoint]\n"
"       import -B [-abcC] [-E clear|ssl|tls] "
"[-e 'crypt auth'|clear] [-p] [-n address] [-s srvname] mountpoint [command [arg ...]]\n");
	exits("usage");
}

/*
 * Interpose the filter command cmd (e.g. aan) on the connection fd.
 *
 * With host == nil (the -B side), announce on anstring and write our
 * local address to the peer.  Otherwise, read the peer's address from
 * fd and build a tcp dial string for host using the peer's port.
 * cmd is started with that address as its last argument; "-c" goes in
 * front of it when host is set.  The filter's fds 0 and 1 are one end
 * of a pipe, and fd is replaced by the other end.
 * Returns fd, which afterwards talks to the filter, not the network.
 * Failures exit via sysfatal.
 */
int
filter(int fd, char *cmd, char *host)
{
	char addr[128], buf[256], *s, *file, *argv[16];
	int lfd, p[2], len, argc;

	if(host == nil){
		/*
		 * Get a free port and post it to the client.
		 * NOTE(review): the fd returned by announce is never closed,
		 * presumably so the announcement outlives this call; confirm.
		 */
		if (announce(anstring, addr) < 0)
			sysfatal("filter: Cannot announce %s: %r", anstring);

		snprint(buf, sizeof(buf), "%s/local", addr);
		if ((lfd = open(buf, OREAD)) < 0)
			sysfatal("filter: Cannot open %s: %r", buf);
		if ((len = read(lfd, buf, sizeof buf - 1)) < 0)
			sysfatal("filter: Cannot read %s: %r", buf);
		close(lfd);
		buf[len] = '\0';
		/* send only the first line */
		if ((s = strchr(buf, '\n')) != nil)
			len = s - buf;
		if (write(fd, buf, len) != len)
			sysfatal("filter: cannot write port; %r");
	} else {
		/* Read address string from connection; only the part after the last '!' is used */
		if ((len = read(fd, buf, sizeof buf - 1)) < 0)
			sysfatal("filter: cannot read port; %r");
		buf[len] = '\0';

		if ((s = strrchr(buf, '!')) == nil)
			sysfatal("filter: illegally formatted port %s", buf);
		/* build net!host!svc, then overwrite its last "!..." with the peer's "!port" */
		strecpy(addr, addr+sizeof(addr), netmkaddr(host, "tcp", s+1));
		strecpy(strrchr(addr, '!'), addr+sizeof(addr), s);
	}

	if(debug)
		fprint(2, "filter: %s\n", addr);

	/* leave room in argv for "-c", addr and the nil terminator */
	snprint(buf, sizeof(buf), "%s", cmd);
	argc = tokenize(buf, argv, nelem(argv)-3);
	if (argc == 0)
		sysfatal("filter: empty command");

	if(host != nil)
		argv[argc++] = "-c";
	argv[argc++] = addr;
	argv[argc] = nil;

	file = argv[0];
	if((s = strrchr(argv[0], '/')) != nil)
		argv[0] = s+1;

	if(pipe(p) < 0)
		sysfatal("pipe: %r");

	switch(rfork(RFNOWAIT|RFPROC|RFMEM|RFFDG|RFREND)) {
	case -1:
		sysfatal("filter: rfork; %r");
	case 0:
		/* child: the filter reads and writes its pipe end on fds 0 and 1 */
		close(fd);
		if (dup(p[0], 1) < 0)
			sysfatal("filter: Cannot dup to 1; %r");
		if (dup(p[0], 0) < 0)
			sysfatal("filter: Cannot dup to 0; %r");
		close(p[0]);
		close(p[1]);
		exec(file, argv);
		sysfatal("filter: exec; %r");
	default:
		/* parent: the other pipe end takes over fd, closing the original connection */
		if (dup(p[1], fd) < 0)
			sysfatal("filter: cannot dup pipe to %d; %r", fd);
		close(p[0]);
		close(p[1]);
	}
	return fd;
}

/*
 * Hex-encode the 10 bytes at f as 20 lowercase digits plus a NUL into t.
 * t must hold at least 21 bytes.
 */
static void
mksecret(char *t, uchar *f)
{
	static char hexdig[] = "0123456789abcdef";
	int i;

	for(i = 0; i < 10; i++){
		*t++ = hexdig[f[i] >> 4];
		*t++ = hexdig[f[i] & 0xf];
	}
	*t = '\0';
}