Let's start naively.
*/
-int getTCPDownstream(policy_t policy, string pool, DownstreamState** ds, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+static int setupTCPDownstream(const ComboAddress& remote)
{
- {
- std::lock_guard<std::mutex> lock(g_luamutex);
- *ds = policy(getDownstreamCandidates(g_dstates.getCopy(), pool), remote, qname, qtype, dh).get();
- }
- vinfolog("TCP connecting to downstream %s", (*ds)->remote.toStringWithPort());
- int sock = SSocket((*ds)->remote.sin4.sin_family, SOCK_STREAM, 0);
- SConnect(sock, (*ds)->remote);
+ vinfolog("TCP connecting to downstream %s", remote.toStringWithPort());
+ int sock = SSocket(remote.sin4.sin_family, SOCK_STREAM, 0);
+ SConnect(sock, remote);
return sock;
}
-bool getMsgLen(int fd, uint16_t* len)
-try
-{
- uint16_t raw;
- int ret = readn2(fd, &raw, 2);
- if(ret != 2)
- return false;
- *len = ntohs(raw);
- return true;
-}
-catch(...) {
- return false;
-}
-
-bool putMsgLen(int fd, uint16_t len)
-try
-{
- uint16_t raw = htons(len);
- int ret = writen2(fd, &raw, 2);
- return ret==2;
-}
-catch(...) {
- return false;
-}
struct ConnectionInfo
{
TCPClientCollection g_tcpclientthreads;
-
-
void* tcpClientThread(int pipefd)
{
/* we get launched with a pipe on which we receive file descriptors from clients that we own
from that point on */
- int dsock = -1;
- DownstreamState *ds=0;
-
+ auto localPolicy = g_policy.getLocal();
+ map<ComboAddress,int> sockets;
for(;;) {
ConnectionInfo* citmp, ci;
readn2(pipefd, &citmp, sizeof(citmp));
--g_tcpclientthreads.d_queued;
ci=*citmp;
- delete citmp;
-
+ delete citmp;
+
uint16_t qlen, rlen;
- string pool; // empty for now
+ string pool; // empty for now, we actually should do ACL, rulactions, the works here! XXX
+ shared_ptr<DownstreamState> ds;
try {
- auto localPolicy = g_policy.getLocal();
for(;;) {
if(!getMsgLen(ci.fd, &qlen))
break;
uint16_t qtype;
DNSName qname(query, qlen, 12, false, &qtype);
struct dnsheader* dh =(dnsheader*)query;
- if(dsock == -1) {
- dsock = getTCPDownstream(localPolicy->policy, pool, &ds, ci.remote, qname, qtype, dh);
+
+ {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), ci.remote, qname, qtype, dh);
}
- else {
- vinfolog("Reusing existing TCP connection to %s", ds->remote.toStringWithPort());
+ int dsock;
+ if(sockets.count(ds->remote) == 0) {
+ dsock=sockets[ds->remote]=setupTCPDownstream(ds->remote);
}
+ else
+ dsock=sockets[ds->remote];
+
ds->queries++;
ds->outstanding++;
if(!putMsgLen(dsock, qlen)) {
vinfolog("Downstream connection to %s died on us, getting a new one!", ds->remote.toStringWithPort());
close(dsock);
- dsock=getTCPDownstream(localPolicy->policy, pool, &ds, ci.remote, qname, qtype, dh);
+ sockets[ds->remote]=dsock=setupTCPDownstream(ds->remote);
goto retry;
}
if(!getMsgLen(dsock, &rlen)) {
vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->remote.toStringWithPort());
close(dsock);
- dsock=getTCPDownstream(localPolicy->policy, pool, &ds, ci.remote, qname, qtype, dh);
+ sockets[ds->remote]=dsock=setupTCPDownstream(ds->remote);
goto retry;
}
vinfolog("Closing client connection with %s", ci.remote.toStringWithPort());
close(ci.fd);
ci.fd=-1;
- --ds->outstanding;
+ if(ds)
+ --ds->outstanding;
}
return 0;
}
return 0;
}
+
+
+bool getMsgLen(int fd, uint16_t* len)
+try
+{
+ uint16_t raw;
+ int ret = readn2(fd, &raw, 2);
+ if(ret != 2)
+ return false;
+ *len = ntohs(raw);
+ return true;
+}
+catch(...) {
+ return false;
+}
+
+bool putMsgLen(int fd, uint16_t len)
+try
+{
+ uint16_t raw = htons(len);
+ int ret = writen2(fd, &raw, 2);
+ return ret==2;
+}
+catch(...) {
+ return false;
+}