ref: 4dbb99f478d68eaeea4c043a839454cb807c2c7d
dir: /sys/src/cmd/ndb/dntcpserver.c/
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <ndb.h>
#include <ip.h>
#include <mp.h>
#include <libsec.h>
#include "dns.h"

enum {
	Maxprocs = 64,	/* max concurrent per-connection children in tcpannounce */
};

static int readmsg(int, uchar*, int);
static int reply(int, uchar *, DNSmsg*, Request*, uchar*);
static int dnzone(int, uchar *, DNSmsg*, DNSmsg*, Request*, uchar*);
static int tcpannounce(char *mntpt, char *addr, char caller[128], char *cert);

/*
 * Serve DNS over TCP on addr in the network stack mounted at mntpt.
 * Each message on the wire is preceded by a two-byte big-endian length.
 * If cert is non-nil, connections are accepted on port 853 and wrapped
 * in TLS; otherwise port 53 is used.
 *
 * The calling process returns immediately; the listener runs in a
 * forked process, and each connection is served by a child forked in
 * tcpannounce until its request deadline passes or the peer hangs up.
 */
void
dntcpserver(char *mntpt, char *addr, char *cert)
{
	volatile int fd, len, rcode, rv;
	volatile long ms;
	volatile char caller[128];
	volatile uchar pkt[Maxpkt], callip[IPaddrlen];
	volatile DNSmsg reqmsg, repmsg;
	volatile Request req;
	volatile RR *edns;
	char *volatile err;

	/*
	 * fork sharing text, data, and bss with parent.
	 * stay in the same note group.
	 */
	switch(rfork(RFPROC|RFMEM|RFNOWAIT)){
	case -1:
		break;
	case 0:
		break;
	default:
		return;
	}

	procsetname("%s: tcp server %s", mntpt, addr);

	/* only per-connection children return from tcpannounce */
	if((fd = tcpannounce(mntpt, addr, caller, cert)) < 0){
		warning("can't announce %s on %s: %r", addr, mntpt);
		_exits(0);
	}
	parseip(callip, caller);
	procsetname("%s: tcp server %s serving %s", mntpt, addr, caller);

	memset(&req, 0, sizeof req);
	req.isslave = 1;
	req.from = caller;
	/* one deadline for the whole connection, not per request */
	req.aborttime = timems() + Maxreqtm;

	/* loop on requests */
	for(;; putactivity(&req)){
		memset(&reqmsg, 0, sizeof reqmsg);
		edns = nil;

		ms = (long)(req.aborttime - nowms);
		if(ms < Minreqtm){
		hangup:
			/*
			 * getactivity has not been called in this
			 * iteration, so there is nothing to put back.
			 */
			close(fd);
			_exits(0);
		}

		/* read the two-byte length, then the message itself */
		alarm(ms);
		if(readn(fd, pkt, 2) != 2){
			alarm(0);
			goto hangup;
		}
		len = pkt[0]<<8 | pkt[1];
		if(len <= 0 || len > Maxtcp || readn(fd, pkt+2, len) != len){
			alarm(0);
			goto hangup;
		}
		alarm(0);

		/* from here on, leaving the loop must go through out: */
		getactivity(&req);
		if((long)(req.aborttime - timems()) < Minreqtm)
			break;

		stats.qrecvdtcp++;

		rcode = Rok;
		err = convM2DNS(pkt+2, len, &reqmsg, &rcode);
		if(err){
			dnslog("%d: server: input err, len %d: %s from %s",
				req.id, len, err, caller);
			free(err);
			break;
		}

		/* reject malformed queries only if the parse itself was clean */
		if(rcode == Rok){
			if(reqmsg.qdcount < 1){
				dnslog("%d: server: no questions from %s",
					req.id, caller);
				break;
			} else if(reqmsg.flags & Fresp){
				dnslog("%d: server: reply not request from %s",
					req.id, caller);
				break;
			} else if((reqmsg.flags & Omask) != Oquery){
				dnslog("%d: server: op %d from %s",
					req.id, reqmsg.flags & Omask, caller);
				break;
			}
		}

		if(reqmsg.qd == nil){
			dnslog("%d: server: no question RR from %s",
				req.id, caller);
			break;
		}

		logrequest(req.id, 0, "rcvd", callip, caller,
			reqmsg.qd->owner->name, reqmsg.qd->type);

		/* answer with an EDNS OPT record only if the query carried one */
		if(rcode == Rok && (reqmsg.edns = getednsopt(&reqmsg, &rcode)) != nil)
			edns = mkednsopt();

		/*
		 * loop through each question.  dnzone pops reqmsg.qd
		 * itself; presumably dnserver does the same, otherwise
		 * this loop would not terminate -- confirm in dnserver.
		 */
		while(reqmsg.qd){
			memset(&repmsg, 0, sizeof(repmsg));
			repmsg.edns = edns;
			if(rcode == Rok && reqmsg.qd->type == Taxfr)
				rv = dnzone(fd, pkt, &reqmsg, &repmsg, &req, callip);
			else {
				dnserver(&reqmsg, &repmsg, &req, callip, rcode);
				rv = reply(fd, pkt, &repmsg, &req, callip);
				freeanswers(&repmsg);
			}
			if(rv < 0)
				goto out;
		}

		rrfreelist(edns);
		rrfreelist(reqmsg.edns);
		freeanswers(&reqmsg);
	}
out:
	/* release the current request's state and its activity slot */
	close(fd);
	rrfreelist(edns);
	rrfreelist(reqmsg.edns);
	freeanswers(&reqmsg);
	putactivity(&req);
	_exits(0);
}

/*
 * Encode rep into pkt after a two-byte big-endian length prefix and
 * write it to fd, giving up at req->aborttime.  pkt must have room for
 * Maxtcp+2 bytes.  Returns the number of bytes written, or -1 if the
 * deadline has already passed or the write was short.
 */
static int
reply(int fd, uchar *pkt, DNSmsg *rep, Request *req, uchar *callip)
{
	int len, rv;
	long ms;

	/* taking too long */
	ms = (long)(req->aborttime - nowms);
	if(ms < 1)
		return -1;

	logreply(req->id, "send", callip, rep);

	len = convDNS2M(rep, pkt+2, Maxtcp);
	pkt[0] = len>>8;
	pkt[1] = len;
	len += 2;

	alarm(ms);
	rv = write(fd, pkt, len);
	alarm(0);
	if(rv != len){
		dnslog("%d: error sending reply to %I: %r",
			req->id, callip);
		rv = -1;
	}
	return rv;
}

/*
 * Return the first entry of servers that matches callip, or nil.
 * An entry matches if its name is "*", if it is an IP literal equal
 * to callip, or if it is a host name that resolves (A for IPv4
 * callers, AAAA otherwise) to an address equal to callip.
 */
static Server*
findserver(uchar *callip, Server *servers, Request *req)
{
	uchar ip[IPaddrlen];
	RR *list, *rp;
	int found;

	for(; servers != nil; servers = servers->next){
		if(strcmp(servers->name, "*") == 0)
			return servers;
		if(strcmp(ipattr(servers->name), "ip") == 0){
			if(parseip(ip, servers->name) == -1)
				continue;
			if(ipcmp(callip, ip) == 0)
				return servers;
			continue;
		}
		list = dnresolve(servers->name, Cin, isv4(callip)?
			Ta: Taaaa, req, nil, 0, Recurse, 0, nil);
		rrfreelist(rrremneg(&list));

		/*
		 * record the match before freeing the list; rp points
		 * into the list and must not be used after rrfreelist.
		 */
		found = 0;
		for(rp = list; rp != nil; rp = rp->next){
			if(parseip(ip, rp->ip->name) == -1)
				continue;
			if(ipcmp(callip, ip) == 0){
				found = 1;
				break;
			}
		}
		rrfreelist(list);
		if(found)
			return servers;
	}
	return nil;
}

/*
 * Answer a zone transfer (AXFR) for the first question in reqp.
 * The question is moved from reqp to repp (advancing reqp->qd).
 * Sends the zone's SOA, then every RR from rrgetzone one per
 * message, then the SOA again.  Callers that are not this host and
 * not listed in the SOA's slaves get the first reply with no
 * answers and nothing more.  Returns the result of the last
 * reply() call (negative on failure).
 */
static int
dnzone(int fd, uchar *pkt, DNSmsg *reqp, DNSmsg *repp, Request *req, uchar *callip)
{
	DN *dp;
	RR *rp;
	int rv;

	repp->id = reqp->id;
	repp->qd = reqp->qd;
	reqp->qd = reqp->qd->next;
	repp->qd->next = nil;
	repp->flags = Fauth | Fresp | Oquery;
	setercode(repp, Rok);
	dp = repp->qd->owner;

	/* send the soa */
	repp->an = rrlookup(dp, Tsoa, NOneg);
	if(repp->an != nil && !myip(callip)
	&& findserver(callip, repp->an->soa->slaves, req) == nil){
		dnslog("%d: dnzone: %I axfr %s - not a dnsslave",
			req->id, callip, dp->name);
		/*
		 * NOTE(review): the refusal is sent as an empty
		 * NOERROR answer; RFC 5936 suggests REFUSED instead.
		 * Left unchanged pending confirmation.
		 */
		rrfreelist(repp->an);
		repp->an = nil;
	}
	rv = reply(fd, pkt, repp, req, callip);
	if(rv < 0 || repp->an == nil)
		goto out;
	rrfreelist(repp->an);
	repp->an = nil;

	/* send each RR of the zone in its own message */
	repp->an = rrgetzone(dp->name);
	while(repp->an != nil) {
		rp = repp->an->next;
		repp->an->next = nil;
		rv = reply(fd, pkt, repp, req, callip);
		rrfreelist(repp->an);
		repp->an = rp;
		if(rv < 0)
			goto out;
	}

	/* resend the soa */
	repp->an = rrlookup(dp, Tsoa, NOneg);
	rv = reply(fd, pkt, repp, req, callip);
out:
	rrfreelist(repp->an);
	repp->an = nil;
	rrfree(repp->qd);
	repp->qd = nil;
	return rv;
}

/*
 * Announce the DNS tcp port (53, or 853 when cert is non-nil) on addr
 * in mntpt and accept connections forever, forking one child per
 * connection.  Only the children return: each returns its data fd
 * (TLS-wrapped if cert was given) with the peer's address from
 * <ldir>/remote copied into caller.  The listener itself returns
 * -1 only on setup or listen failure.
 */
static int
tcpannounce(char *mntpt, char *addr, char caller[128], char *cert)
{
	char adir[NETPATHLEN], ldir[NETPATHLEN], buf[128];
	int acfd, lcfd, dfd, wfd, rfd, procs;
	PEMChain *chain = nil;

	if(cert != nil){
		chain = readcertchain(cert);
		if(chain == nil)
			return -1;
	}

	/* announce tcp dns port */
	snprint(buf, sizeof(buf), "%s/tcp!%s!%s", mntpt, addr, cert == nil ? "53" : "853");
	acfd = announce(buf, adir);
	if(acfd < 0)
		return -1;

	/* open wait file to maintain child process count */
	snprint(buf, sizeof(buf), "/proc/%d/wait", getpid());
	wfd = open(buf, OREAD|OCEXEC);
	if(wfd < 0){
		close(acfd);
		return -1;
	}

	procs = 0;
	for(;;) {
		/*
		 * Reap exited children.  At Maxprocs, block in read
		 * until one exits; otherwise only consume wait
		 * messages while the wait file reports a non-zero
		 * length, so read does not block.
		 */
		if(procs >= Maxprocs || (procs % 8) == 0){
			while(procs > 0){
				if(procs < Maxprocs){
					Dir *d = dirfstat(wfd);
					if(d == nil || d->length == 0){
						free(d);
						break;
					}
					free(d);
				}
				if(read(wfd, buf, sizeof(buf)) <= 0){
					procs = 0;
					break;
				}
				procs--;
			}
		}

		lcfd = listen(adir, ldir);
		if(lcfd < 0){
			close(wfd);
			close(acfd);
			return -1;
		}

		/*
		 * Without RFFDG the child shares the fd table, so the
		 * child's close(lcfd) also closes it for the listener.
		 * With RFMEM the heap is shared too, so anything the
		 * child allocates must be freed before it exits.
		 */
		switch(rfork(RFPROC|RFMEM)){
		case -1:
			close(lcfd);
			break;
		case 0:
			dfd = accept(lcfd, ldir);
			close(lcfd);
			if(dfd < 0)
				_exits(0);

			if(chain != nil){
				TLSconn conn;
				int fd;

				memset(&conn, 0, sizeof conn);
				conn.cert = emalloc(conn.certlen = chain->pemlen);
				memmove(conn.cert, chain->pem, conn.certlen);
				conn.chain = chain->next;
				fd = tlsServer(dfd, &conn);
				/*
				 * free on both paths: on failure the
				 * copy would otherwise leak in the heap
				 * shared with the listener.
				 */
				free(conn.cert);
				free(conn.sessionID);
				if(fd < 0){
					close(dfd);
					_exits(0);
				}
				dfd = fd;
			}

			/* get the callers ip!port */
			memset(caller, 0, 128);
			snprint(buf, sizeof(buf), "%s/remote", ldir);
			if((rfd = open(buf, OREAD|OCEXEC)) >= 0){
				read(rfd, caller, 128-1);
				close(rfd);
			}

			/* child returns */
			return dfd;
		default:
			procs++;
		}
	}
}