shithub: riscv

ref: 1787584ad839002584e0752e4df41b6f3b70064b
dir: /sys/src/liboventi/rpc.c/

View raw version
#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"

struct {
	int version;
	char *s;
} vtVersions[] = {
	VtVersion02, "02",
	0, 0,
};

/* error strings handed to vtSetError by the routines in this file */
static char EBigString[] = "string too long";
static char EBigPacket[] = "packet too long";
static char ENullString[] = "missing string";
static char EBadVersion[] = "bad format in version string";

/* forward declaration */
static Packet *vtRPC(VtSession *z, int op, Packet *p);


/*
 * Allocate and initialize a session: its three locks, an empty
 * packet in `part', no file descriptor (-1), and "anonymous" as
 * both uid and sid.  The inHash and outHash contexts are not
 * allocated here.
 */
VtSession *
vtAlloc(void)
{
	VtSession *s;

	s = vtMemAllocZ(sizeof(VtSession));

	s->lk = vtLockAlloc();
	s->inLock = vtLockAlloc();
	s->part = packetAlloc();
	s->outLock = vtLockAlloc();

	s->fd = -1;
	s->uid = vtStrDup("anonymous");
	s->sid = vtStrDup("anonymous");

	return s;
}

/*
 * Put the session back in the VtStateAlloc state and close its
 * file descriptor if one is open.  Runs under z->lk.
 */
void
vtReset(VtSession *z)
{
	vtLock(z->lk);
	z->cstate = VtStateAlloc;
	if(z->fd < 0){
		/* nothing to close */
		vtUnlock(z->lk);
		return;
	}
	vtFdClose(z->fd);
	z->fd = -1;
	vtUnlock(z->lk);
}

/*
 * Report whether the session is in the VtStateConnected state:
 * 1 if so, 0 otherwise.  The state is read without taking z->lk.
 */
int
vtConnected(VtSession *z)
{
	if(z->cstate == VtStateConnected)
		return 1;
	return 0;
}

/*
 * Shut the session down.  If it is connected, `error' is zero and
 * no vtbl is installed, a two-byte VtQGoodbye packet is sent
 * first.  In every case the file descriptor is closed (if open)
 * and the state becomes VtStateClosed.  Runs under z->lk.
 */
void
vtDisconnect(VtSession *z, int error)
{
	Packet *bye;
	uchar *hdr;
	int clean;

	vtDebug(z, "vtDisconnect\n");

	vtLock(z->lk);
	clean = z->cstate == VtStateConnected && !error && z->vtbl == nil;
	if(clean){
		bye = packetAlloc();
		hdr = packetHeader(bye, 2);
		hdr[0] = VtQGoodbye;
		hdr[1] = 0;
		vtSendPacket(z, bye);
	}

	if(z->fd >= 0)
		vtFdClose(z->fd);
	z->fd = -1;
	z->cstate = VtStateClosed;
	vtUnlock(z->lk);
}

/*
 * Close the session: vtDisconnect with the error flag clear.
 */
void
vtClose(VtSession *z)
{
	int error;

	error = 0;
	vtDisconnect(z, error);
}

/*
 * Release everything the session owns (locks, hash contexts, the
 * partial-input packet, uid, sid, vtbl) and then the session
 * itself.  A nil session is ignored.  The file descriptor is not
 * closed here.
 */
void
vtFree(VtSession *z)
{
	if(z != nil){
		vtLockFree(z->lk);
		vtSha1Free(z->inHash);
		vtLockFree(z->inLock);
		packetFree(z->part);
		vtSha1Free(z->outHash);
		vtLockFree(z->outLock);
		vtMemFree(z->uid);
		vtMemFree(z->sid);
		vtMemFree(z->vtbl);

		/* wipe the structure before handing it back */
		memset(z, 0, sizeof(VtSession));
		z->fd = -1;

		vtMemFree(z);
	}
}

/*
 * Return the session's uid string.  The pointer refers to the
 * session's own copy, not a duplicate.
 */
char *
vtGetUid(VtSession *s)
{
	char *uid;

	uid = s->uid;
	return uid;
}

/*
 * Return the session's sid string.  The pointer refers to the
 * session's own copy, not a duplicate.
 */
char *
vtGetSid(VtSession *z)
{
	char *sid;

	sid = z->sid;
	return sid;
}

/*
 * Set the session's debug flag under z->lk and return the value
 * it had before.
 */
int
vtSetDebug(VtSession *z, int debug)
{
	int prev;

	vtLock(z->lk);
	prev = z->debug;
	z->debug = debug;
	vtUnlock(z->lk);

	return prev;
}

/*
 * Install fd as the session's file descriptor, closing any
 * descriptor already held.  Only allowed in the VtStateAlloc
 * state; otherwise sets the error "bad state" and returns 0.
 * Returns 1 on success.
 */
int
vtSetFd(VtSession *z, int fd)
{
	int ok;

	vtLock(z->lk);
	ok = z->cstate == VtStateAlloc;
	if(ok){
		if(z->fd >= 0)
			vtFdClose(z->fd);
		z->fd = fd;
	}else
		vtSetError("bad state");
	vtUnlock(z->lk);

	return ok;
}

/*
 * Return the session's file descriptor (-1 when none is set).
 * Read without taking z->lk.
 */
int
vtGetFd(VtSession *z)
{
	int fd;

	fd = z->fd;
	return fd;
}

/*
 * Validate a requested crypto strength.  Succeeds (returns 1)
 * only in the VtStateAlloc state and only for
 * VtCryptoStrengthNone; otherwise sets an error and returns 0.
 * Nothing is stored in the session.
 */
int
vtSetCryptoStrength(VtSession *z, int c)
{
	if(z->cstate == VtStateAlloc){
		if(c == VtCryptoStrengthNone)
			return 1;
		vtSetError("not supported yet");
		return 0;
	}
	vtSetError("bad state");
	return 0;
}

/*
 * Return the session's cryptoStrength field.
 */
int
vtGetCryptoStrength(VtSession *s)
{
	int strength;

	strength = s->cryptoStrength;
	return strength;
}

/*
 * Record the compression setting for the session.  Only allowed
 * in the VtStateAlloc state; otherwise sets the error "bad state"
 * and returns 0.  Returns 1 on success.
 *
 * The parameter is named `fd' for historical reasons but carries
 * the compression value.  It is stored in z->compression, the
 * field vtGetCompression reads; it must not overwrite z->fd,
 * which would discard the session's file descriptor.
 */
int
vtSetCompression(VtSession *z, int fd)
{
	vtLock(z->lk);
	if(z->cstate != VtStateAlloc) {
		vtSetError("bad state");
		vtUnlock(z->lk);
		return 0;
	}
	z->compression = fd;
	vtUnlock(z->lk);
	return 1;
}

/*
 * Return the session's compression field.
 */
int
vtGetCompression(VtSession *s)
{
	int compression;

	compression = s->compression;
	return compression;
}

/*
 * Return the session's crypto field.
 */
int
vtGetCrypto(VtSession *s)
{
	int crypto;

	crypto = s->crypto;
	return crypto;
}

/*
 * Return the session's codec field.
 */
int
vtGetCodec(VtSession *s)
{
	int codec;

	codec = s->codec;
	return codec;
}

/*
 * Return the token naming the session's protocol version, looked
 * up in vtVersions, or "unknown" while the version is still 0.
 * A non-zero version missing from the table trips the assert.
 */
char *
vtGetVersion(VtSession *z)
{
	int want, i;

	want = z->version;
	if(want == 0)
		return "unknown";

	i = 0;
	while(vtVersions[i].version != 0){
		if(vtVersions[i].version == want)
			return vtVersions[i].s;
		i++;
	}

	assert(0);
	return 0;
}

/*
 * Read the peer's version line from z->fd and pick a version.
 * The caller must hold z->inLock.
 *
 * The line is read one byte at a time up to a newline.  It must
 * begin with `prefix', contain no bytes below ' ', and fit in
 * VtMaxStringSize bytes; otherwise EBadVersion is set and 0 is
 * returned.  Each byte read is also fed to z->inHash when that
 * is non-nil.
 *
 * After the prefix comes a ':'-separated list of version tokens,
 * ended by '-' or the end of the line.  The first token found in
 * vtVersions is stored in *ret and 1 is returned.  If no token
 * matches, 0 is returned without setting a new error string.
 */
static int
vtVersionRead(VtSession *z, char *prefix, int *ret)
{
	char c;
	char buf[VtMaxStringSize];
	char *q, *p, *pp;
	int i;

	q = prefix;
	p = buf;
	for(;;) {
		/* leave room for the terminating NUL */
		if(p >= buf + sizeof(buf)) {
			vtSetError(EBadVersion);
			return 0;
		}
		if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
			return 0;
		if(z->inHash)
			vtSha1Update(z->inHash, (uchar*)&c, 1);
		if(c == '\n') {
			*p = 0;
			break;
		}
		/* reject control bytes and any mismatch inside the prefix */
		if(c < ' ' || *q && c != *q) {
			vtSetError(EBadVersion);
			return 0;
		}
		*p++ = c;
		if(*q)
			q++;
	}
		
	vtDebug(z, "version string in: %s\n", buf);

	/*
	 * If the line ended before the whole prefix was seen, buf is
	 * shorter than the prefix; skipping strlen(prefix) bytes would
	 * step over the terminator into uninitialized memory.
	 */
	if(*q != 0) {
		vtSetError(EBadVersion);
		return 0;
	}

	p = buf + strlen(prefix);
	for(;;) {
		/* pp marks the end of the current token */
		for(pp=p; *pp && *pp != ':'  && *pp != '-'; pp++)
			;
		for(i=0; vtVersions[i].version; i++) {
			if(strlen(vtVersions[i].s) != pp-p)
				continue;
			if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
				*ret = vtVersions[i].version;
				return 1;
			}
		}
		p = pp;
		/* only ':' introduces another token */
		if(*p != ':')
			return 0;
		p++;
	}	
}

/*
 * Read one framed packet from the session.  Each packet on the
 * wire is preceded by a two-byte big-endian length.
 *
 * Returns the packet, or 0/nil on failure.  If the session is not
 * in the VtStateConnected state the error "session not connected"
 * is set; read failures return nil without setting an error here.
 * z->inLock is held for the duration of the read.
 */
Packet*
vtRecvPacket(VtSession *z)
{
	uchar buf[10], *b;
	int n;
	Packet *p;
	int size, len;

	if(z->cstate != VtStateConnected) {
		vtSetError("session not connected");
		return 0;
	}

	vtLock(z->inLock);
	/* z->part may already hold bytes; size tracks how many */
	p = z->part;
	/* read until at least the two length bytes are available */
	size = packetSize(p);
	while(size < 2) {
		b = packetTrailer(p, MaxFragSize);
		assert(b != nil);
		n = vtFdRead(z->fd, b, MaxFragSize);
		if(n <= 0)
			goto Err;
		size += n;
		/* presumably cuts the packet back to the bytes actually read -- confirm against packetTrim */
		packetTrim(p, 0, size);
	}

	if(!packetConsume(p, buf, 2))
		goto Err;
	len = (buf[0] << 8) | buf[1];
	size -= 2;

	/* read the rest of the body, at most MaxFragSize bytes at a time */
	while(size < len) {
		n = len - size;
		if(n > MaxFragSize)
			n = MaxFragSize;
		b = packetTrailer(p, n);
		if(!vtFdReadFully(z->fd, b, n))
			goto Err;
		size += n;
	}
	/* presumably returns the first len bytes as a new packet and leaves the remainder in z->part -- confirm against packetSplit */
	p = packetSplit(p, len);
	vtUnlock(z->inLock);
	return p;
Err:	
	vtUnlock(z->inLock);
	return nil;	
}

/*
 * Write packet p to the session, preceded by its two-byte
 * big-endian length.  Packets of 1<<16 bytes or more are refused
 * with EBigPacket.  Returns 1 on success and 0 on failure.
 * p is freed on every path, so the caller must not use it again.
 */
int
vtSendPacket(VtSession *z, Packet *p)
{
	IOchunk chunk;
	int n, ok;
	uchar hdr[2];

	n = packetSize(p);
	if(n > 0xFFFF){
		vtSetError(EBigPacket);
		packetFree(p);
		return 0;
	}

	/* add framing */
	hdr[0] = n>>8;
	hdr[1] = n;
	packetPrefix(p, hdr, 2);

	/* write one fragment at a time until the packet is empty */
	ok = 1;
	while((n = packetFragments(p, &chunk, 1, 0)) != 0){
		if(!vtFdWrite(z->fd, chunk.addr, chunk.len)){
			ok = 0;
			break;
		}
		packetConsume(p, nil, n);
	}

	packetFree(p);
	return ok;
}


/*
 * Pull a string off the front of packet p.  On the wire a string
 * is a two-byte big-endian length followed by that many bytes.
 * On success a freshly allocated, NUL-terminated copy is stored
 * in *ret and 1 is returned.  Returns 0 if the packet is too
 * short or the length exceeds VtMaxStringSize (EBigString is set
 * in the latter case); *ret is then left untouched.
 */
int
vtGetString(Packet *p, char **ret)
{
	uchar hdr[2];
	int len;
	char *str;

	if(!packetConsume(p, hdr, 2))
		return 0;

	len = (hdr[0] << 8) | hdr[1];
	if(len > VtMaxStringSize){
		vtSetError(EBigString);
		return 0;
	}

	str = vtMemAlloc(len+1);
	/* tag the allocation with our caller's pc */
	setmalloctag(str, getcallerpc(&p));
	if(packetConsume(p, (uchar*)str, len)){
		str[len] = 0;
		*ret = str;
		return 1;
	}

	vtMemFree(str);
	return 0;
}

/*
 * Append string s to packet p as a two-byte big-endian length
 * followed by the bytes of s (no terminating NUL).  Returns 1 on
 * success.  Returns 0 with ENullString set if s is nil, or with
 * EBigString set if s is longer than VtMaxStringSize; the packet
 * is not modified in either failure case.
 */
int
vtAddString(Packet *p, char *s)
{
	uchar hdr[2];
	int len;

	if(s == nil){
		vtSetError(ENullString);
		return 0;
	}

	len = strlen(s);
	if(len > VtMaxStringSize){
		vtSetError(EBigString);
		return 0;
	}

	hdr[0] = len>>8;
	hdr[1] = len;
	packetAppend(p, hdr, 2);
	packetAppend(p, (uchar*)s, len);
	return 1;
}

/*
 * Perform the connect handshake on a session whose file
 * descriptor has already been set.
 *
 * The session must be in the VtStateAlloc state with a valid fd;
 * otherwise an error is set and 0 is returned with the session
 * unchanged.  The routine writes the version line
 * "venti-<versions>-libventi\n", where <versions> is the
 * ':'-separated list from vtVersions, reads the peer's version
 * line, and moves to VtStateConnected.  If no vtbl is installed
 * it then calls vtHello.
 *
 * Returns 1 on success.  On any failure after the initial checks
 * the fd is closed, the state becomes VtStateClosed and 0 is
 * returned.  `password' is unused.
 */
int
vtConnect(VtSession *z, char *password)
{
	char buf[VtMaxStringSize], *p, *ep, *prefix;
	int i;

	USED(password);
	vtLock(z->lk);
	if(z->cstate != VtStateAlloc) {
		vtSetError("bad session state");
		vtUnlock(z->lk);
		return 0;
	}
	if(z->fd < 0){
		vtSetError("%s", z->fderror);
		vtUnlock(z->lk);
		return 0;
	}

	/* hold both I/O locks for the whole version exchange */
	vtLock(z->inLock);
	vtLock(z->outLock);

	/*
	 * Build the version line.  seprint returns a pointer to the
	 * terminating NUL and never writes past ep, so each piece,
	 * including the ':' separators, is appended through it.
	 */
	prefix = "venti-";
	ep = buf + sizeof(buf);
	p = seprint(buf, ep, "%s", prefix);
	for(i=0; vtVersions[i].version; i++) {
		if(i != 0)
			p = seprint(p, ep, ":");
		p = seprint(p, ep, "%s", vtVersions[i].s);
	}
	p = seprint(p, ep, "-libventi\n");
	assert(p-buf < sizeof(buf));
	if(z->outHash)
		vtSha1Update(z->outHash, (uchar*)buf, p-buf);
	if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
		goto Err;
	
	vtDebug(z, "version string out: %s", buf);

	if(!vtVersionRead(z, prefix, &z->version))
		goto Err;
		
	vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));

	vtUnlock(z->inLock);
	vtUnlock(z->outLock);
	z->cstate = VtStateConnected;
	vtUnlock(z->lk);

	if(z->vtbl)
		return 1;

	if(!vtHello(z)){
		/*
		 * All three locks were released above, so the Err path,
		 * which unlocks them, must not be taken from here.
		 * Retake z->lk just to tear the session down.
		 */
		vtLock(z->lk);
		if(z->fd >= 0)
			vtFdClose(z->fd);
		z->fd = -1;
		z->cstate = VtStateClosed;
		vtUnlock(z->lk);
		return 0;
	}
	return 1;	

Err:
	/* reached only with z->lk, z->inLock and z->outLock held */
	if(z->fd >= 0)
		vtFdClose(z->fd);
	z->fd = -1;
	vtUnlock(z->inLock);
	vtUnlock(z->outLock);
	z->cstate = VtStateClosed;
	vtUnlock(z->lk);
	return 0;	
}