shithub: riscv

ref: 72a168b5fb1d3b40bf6c8051d06d7237f27a7d2b
dir: /sys/src/liboventi/client.c/

View raw version
#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"

static char EProtocolBotch[] = "venti protocol botch";
static char ELumpSize[] = "illegal lump size";
static char ENotConnected[] = "not connected to venti server";

static Packet *vtRPC(VtSession *z, int op, Packet *p);

VtSession *
vtClientAlloc(void)
{
	VtSession *z = vtAlloc();	
	return z;
}

VtSession *
vtDial(char *host, int canfail)
{
	VtSession *z;
	int fd;
	char *na;
	char e[ERRMAX];

	if(host == nil) 
		host = getenv("venti");
	if(host == nil)
		host = "$venti";

	if (host == nil) {
		if (!canfail)
			werrstr("no venti host set");
		na = "";
		fd = -1;
	} else {
		na = netmkaddr(host, 0, "venti");
		fd = dial(na, 0, 0, 0);
	}
	if(fd < 0){
		rerrstr(e, sizeof e);
		if(!canfail){
			vtSetError("venti dialstring %s: %s", na, e);
			return nil;
		}
	}
	z = vtClientAlloc();
	if(fd < 0)
		strcpy(z->fderror, e);
	vtSetFd(z, fd);
	return z;
}

int
vtRedial(VtSession *z, char *host)
{
	int fd;
	char *na;

	if(host == nil) 
		host = getenv("venti");
	if(host == nil)
		host = "$venti";

	na = netmkaddr(host, 0, "venti");
	fd = dial(na, 0, 0, 0);
	if(fd < 0){
		vtOSError();
		return 0;
	}
	vtReset(z);
	vtSetFd(z, fd);
	return 1;
}

VtSession *
vtStdioServer(char *server)
{
	int pfd[2];
	VtSession *z;

	if(server == nil)
		return nil;

	if(access(server, AEXEC) < 0) {
		vtOSError();
		return nil;
	}

	if(pipe(pfd) < 0) {
		vtOSError();
		return nil;
	}

	switch(fork()) {
	case -1:
		close(pfd[0]);
		close(pfd[1]);
		vtOSError();
		return nil;
	case 0:
		close(pfd[0]);
		dup(pfd[1], 0);
		dup(pfd[1], 1);
		execl(server, "ventiserver", "-i", nil);
		exits("exec failed");
	}
	close(pfd[1]);

	z = vtClientAlloc();
	vtSetFd(z, pfd[0]);
	return z;
}

int
vtPing(VtSession *z)
{
	Packet *p = packetAlloc();

	p = vtRPC(z, VtQPing, p);
	if(p == nil)
		return 0;
	packetFree(p);
	return 1;
}

int
vtHello(VtSession *z)
{
	Packet *p;
	uchar buf[10];
	char *sid;
	int crypto, codec;

	sid = nil;

	p = packetAlloc();
	if(!vtAddString(p, vtGetVersion(z)))
		goto Err;
	if(!vtAddString(p, vtGetUid(z)))
		goto Err;
	buf[0] = vtGetCryptoStrength(z);
	buf[1] = 0;
	buf[2] = 0;
	packetAppend(p, buf, 3);
	p = vtRPC(z, VtQHello, p);
	if(p == nil)
		return 0;
	if(!vtGetString(p, &sid))
		goto Err;
	if(!packetConsume(p, buf, 2))
		goto Err;
	if(packetSize(p) != 0) {
		vtSetError(EProtocolBotch);
		goto Err;
	}
	crypto = buf[0];
	codec = buf[1];

	USED(crypto);
	USED(codec);

	packetFree(p);

	vtLock(z->lk);
	z->sid = sid;
	z->auth.state = VtAuthOK;
	vtSha1Free(z->inHash);
	z->inHash = nil;
	vtSha1Free(z->outHash);
	z->outHash = nil;
	vtUnlock(z->lk);

	return 1;
Err:
	packetFree(p);
	vtMemFree(sid);
	return 0;
}

int
vtSync(VtSession *z)
{
	Packet *p = packetAlloc();

	p = vtRPC(z, VtQSync, p);
	if(p == nil)
		return 0;
	if(packetSize(p) != 0){
		vtSetError(EProtocolBotch);
		goto Err;
	}
	packetFree(p);
	return 1;

Err:
	packetFree(p);
	return 0;
}

int
vtWrite(VtSession *z, uchar score[VtScoreSize], int type, uchar *buf, int n)
{
	Packet *p = packetAlloc();

	packetAppend(p, buf, n);
	return vtWritePacket(z, score, type, p);
}

int
vtWritePacket(VtSession *z, uchar score[VtScoreSize], int type, Packet *p)
{
	int n = packetSize(p);
	uchar *hdr;

	if(n > VtMaxLumpSize || n < 0) {
		vtSetError(ELumpSize);
		goto Err;
	}
	
	if(n == 0) {
		memmove(score, vtZeroScore, VtScoreSize);
		return 1;
	}

	hdr = packetHeader(p, 4);
	hdr[0] = type;
	hdr[1] = 0;	/* pad */
	hdr[2] = 0;	/* pad */
	hdr[3] = 0;	/* pad */
	p = vtRPC(z, VtQWrite, p);
	if(p == nil)
		return 0;
	if(!packetConsume(p, score, VtScoreSize))
		goto Err;
	if(packetSize(p) != 0) {
		vtSetError(EProtocolBotch);
		goto Err;
	}
	packetFree(p);
	return 1;
Err:
	packetFree(p);
	return 0;
}

int
vtRead(VtSession *z, uchar score[VtScoreSize], int type, uchar *buf, int n)
{
	Packet *p;

	p = vtReadPacket(z, score, type, n);
	if(p == nil)
		return -1;
	n = packetSize(p);
	packetCopy(p, buf, 0, n);
	packetFree(p);
	return n;
}

Packet *
vtReadPacket(VtSession *z, uchar score[VtScoreSize], int type, int n)
{
	Packet *p;
	uchar buf[10];

	if(n < 0 || n > VtMaxLumpSize) {
		vtSetError(ELumpSize);
		return nil;
	}

	p = packetAlloc();
	if(memcmp(score, vtZeroScore, VtScoreSize) == 0)
		return p;

	packetAppend(p, score, VtScoreSize);
	buf[0] = type;
	buf[1] = 0;	/* pad */
	buf[2] = n >> 8;
	buf[3] = n;
	packetAppend(p, buf, 4);
	return vtRPC(z, VtQRead, p);
}


static Packet *
vtRPC(VtSession *z, int op, Packet *p)
{
	uchar *hdr, buf[2];
	char *err;

	if(z == nil){
		vtSetError(ENotConnected);
		return nil;
	}

	/*
	 * single threaded for the momment
	 */
	vtLock(z->lk);
	if(z->cstate != VtStateConnected){
		vtSetError(ENotConnected);
		goto Err;
	}
	hdr = packetHeader(p, 2);
	hdr[0] = op;	/* op */
	hdr[1] = 0;	/* tid */
	vtDebug(z, "client send: ");
	vtDebugMesg(z, p, "\n");
	if(!vtSendPacket(z, p)) {
		p = nil;
		goto Err;
	}
	p = vtRecvPacket(z);
	if(p == nil)
		goto Err;
	vtDebug(z, "client recv: ");
	vtDebugMesg(z, p, "\n");
	if(!packetConsume(p, buf, 2))
		goto Err;
	if(buf[0] == VtRError) {
		if(!vtGetString(p, &err)) {
			vtSetError(EProtocolBotch);
			goto Err;
		}
		vtSetError(err);
		vtMemFree(err);
		packetFree(p);
		vtUnlock(z->lk);
		return nil;
	}
	if(buf[0] != op+1 || buf[1] != 0) {
		vtSetError(EProtocolBotch);
		goto Err;
	}
	vtUnlock(z->lk);
	return p;
Err:
	vtDebug(z, "vtRPC failed: %s\n", vtGetError());
	if(p != nil)
		packetFree(p);
	vtUnlock(z->lk);
	vtDisconnect(z, 1);
	return nil;
}