]> granicus.if.org Git - pdns/commitdiff
move full-blown djbdns-style socket-per-query support
authorBert Hubert <bert.hubert@netherlabs.nl>
Fri, 14 Apr 2006 18:00:12 +0000 (18:00 +0000)
committerBert Hubert <bert.hubert@netherlabs.nl>
Fri, 14 Apr 2006 18:00:12 +0000 (18:00 +0000)
add --single-socket to fall back to old behaviour
make --query-local-port set that feature

git-svn-id: svn://svn.powerdns.com/pdns/trunk/pdns@702 d19b8d6e-7fed-0310-83ef-9ca221ded41b

pdns/lwres.cc
pdns/lwres.hh
pdns/pdns_recursor.cc
pdns/syncres.hh

index 93a18b042846fa7652710f77c56a6f5e974eb152..cf50b063313b2b101004cab0d587cebd60670493 100644 (file)
@@ -82,13 +82,14 @@ int LWRes::asyncresolve(uint32_t ip, const string& domain, int type, bool doTCP,
   dt.setTimeval(*now);
 
   if(!doTCP) {
-    if(asendto((const char*)&*vpacket.begin(), vpacket.size(), 0, (struct sockaddr*)(&toaddr), sizeof(toaddr), pw.getHeader()->id)<0) {
+    int queryfd;
+    if(asendto((const char*)&*vpacket.begin(), vpacket.size(), 0, (struct sockaddr*)(&toaddr), sizeof(toaddr), pw.getHeader()->id, &queryfd) < 0) {
       return -1;
     }
   
     // sleep until we see an answer to this, interface to mtasker
     
-    ret=arecvfrom(reinterpret_cast<char *>(d_buf), d_bufsize-1,0,(struct sockaddr*)(&toaddr), &addrlen, &d_len, pw.getHeader()->id, domain);
+    ret=arecvfrom(reinterpret_cast<char *>(d_buf), d_bufsize-1,0,(struct sockaddr*)(&toaddr), &addrlen, &d_len, pw.getHeader()->id, domain, queryfd);
   }
   else {
     Socket s(InterNetwork, Stream);
index 66d999072dfe6b89636ae1fcccab2444a41ad2d2..078f346be3c6ac8b095eb65f96429622d8d10e18 100644 (file)
@@ -43,8 +43,8 @@
 #include "dns.hh"
 using namespace std;
 
-int asendto(const char *data, int len, int flags, struct sockaddr *toaddr, int addrlen, int id);
-int arecvfrom(char *data, int len, int flags, struct sockaddr *toaddr, Utility::socklen_t *addrlen, int *d_len, int id, const string& domain);
+int asendto(const char *data, int len, int flags, struct sockaddr *toaddr, int addrlen, int id, int* fd);
+int arecvfrom(char *data, int len, int flags, struct sockaddr *toaddr, Utility::socklen_t *addrlen, int *d_len, int id, const string& domain, int fd);
 
 class LWResException : public AhuException
 {
index 418d51afe398566d48ee9a403c78605270411c4f..38befcfd649204f6ecfa5e5a24fd88ae7549bbd8 100644 (file)
@@ -118,9 +118,8 @@ ArgvMap &arg()
   static ArgvMap theArg;
   return theArg;
 }
-static int d_clientsock;
-static int d_prevclientsock;
-static vector<int> d_udpserversocks;
+
+static vector<int> g_udpserversocks;
 
 typedef vector<int> tcpserversocks_t;
 static tcpserversocks_t s_tcpserversocks;
@@ -163,22 +162,125 @@ int arecvtcp(string& data, int len, Socket* sock)
   return ret;
 }
 
+int makeClientSocket()
+{
+  int ret=socket(AF_INET, SOCK_DGRAM, 0);
+  if(ret<0) 
+    throw AhuException("Making a socket for resolver: "+stringerror());
+
+  static optional<struct sockaddr_in> sin;
+  int queryPort=0;
+  if(!sin) {
+    struct sockaddr_in tmp;
+    sin=tmp;
+    memset((char *)&*sin,0, sizeof(sin));
+    sin->sin_family = AF_INET;
+    
+    if(!IpToU32(::arg()["query-local-address"], &sin->sin_addr.s_addr))
+      throw AhuException("Unable to resolve local address '"+ ::arg()["query-local-address"] +"'"); 
+
+    queryPort=::arg().asNum("query-local-port");
+  }
+
+  int tries=100;
+  while(--tries) {
+    uint16_t port=10000+Utility::random()%50000;
+    if(queryPort) {
+      port=queryPort;
+      tries=1;
+    }
+    sin->sin_port = htons(port); 
+    
+    if (::bind(ret, (struct sockaddr *)&*sin, sizeof(sin)) >= 0) 
+      break;
+  }
+  if(!tries)
+    throw AhuException("Resolver binding to local query client socket: "+stringerror());
+
+  Utility::setNonBlocking(ret);
+  return ret;
+}
+
+// you can ask this class for a UDP socket to send a query from
+// this socket is not yours, don't even think about deleting it
+// but after you call 'returnSocket' on it, don't assume anything anymore
+class UDPClientSocks
+{
+  bool d_passthrough;
+  unsigned int d_numsocks;
+  unsigned int d_maxsocks;
+public:
+  UDPClientSocks() : d_passthrough(false) , d_numsocks(0), d_maxsocks(500)
+  {
+  }
+
+  typedef map<int,int> socks_t;
+  socks_t d_socks;
+
+  void setPassthrough(bool state)
+  {
+    if((d_passthrough=state)) {
+      pair<int, int> sock=make_pair(makeClientSocket(), 1);
+      d_socks.insert(sock);
+    }
+  }
+  
+  int getSocket()
+  {
+    if(d_passthrough)
+      return d_socks.begin()->first;
+
+    if(d_numsocks < d_maxsocks) {
+      pair<int, int> sock=make_pair(makeClientSocket(), 1);
+      d_socks.insert(sock);
+      d_numsocks++;
+      return sock.first;
+    }
+    else {
+      socks_t::iterator pos=d_socks.begin();
+      advance(pos, random() % d_socks.size());
+      pos->second++;
+      return pos->first;
+    }
+  }
+
+  // return a socket to the pool, or simply erase it
+  void returnSocket(socks_t::iterator i)
+  {
+    if(d_passthrough) {
+      ++i;
+      return;
+    }
+
+    if(!--i->second) {
+      ::close(i->first);
+      d_socks.erase(i++);
+      --d_numsocks;
+    }
+    else {
+      ++i;
+    }
+  }
+}g_udpclientsocks;
+
 
 /* these two functions are used by LWRes */
 // -1 is error, > 1 is success
-int asendto(const char *data, int len, int flags, struct sockaddr *toaddr, int addrlen, int id) 
+int asendto(const char *data, int len, int flags, struct sockaddr *toaddr, int addrlen, int id, int* fd
 {
-  return sendto(d_clientsock, data, len, flags, toaddr, addrlen);
+  *fd=g_udpclientsocks.getSocket();
+  return sendto(*fd, data, len, flags, toaddr, addrlen);
 }
 
 // -1 is error, 0 is timeout, 1 is success
-int arecvfrom(char *data, int len, int flags, struct sockaddr *toaddr, Utility::socklen_t *addrlen, int *d_len, int id, const string& domain)
+int arecvfrom(char *data, int len, int flags, struct sockaddr *toaddr, Utility::socklen_t *addrlen, int *d_len, int id, const string& domain, int fd)
 {
   static optional<unsigned int> nearMissLimit;
   if(!nearMissLimit) 
     nearMissLimit=::arg().asNum("spoof-nearmiss-max");
 
   PacketID pident;
+  pident.fd=fd;
   pident.id=id;
   pident.domain=domain;
   memcpy(&pident.remote, toaddr, sizeof(pident.remote));
@@ -462,63 +564,6 @@ void makeControlChannelSocket()
   s_rcc.listen(sockname);
 }
 
-// this stuff is a tad complicated. There are two client sockets, the current one and the previous one (prevclientsocket)
-// if this function is called, and more than 5 seconds have passed since the previous call, the previous client socket is closed,
-// and replaced by the current one, which is then reopened.
-void remakeClientSocket()
-{
-  static time_t lastChange;
-
-  if(d_clientsock>=0 && !::arg()["query-local-port"].empty()) // already have a port, and we are fixed
-    return;
-
-  if(!lastChange)
-    lastChange=time(0)-10;
-
-  if(lastChange > time(0) - 5)
-    return;
-
-  lastChange=time(0);
-
-  if(d_prevclientsock >= 0) {
-    close(d_prevclientsock);
-  }
-  d_prevclientsock=d_clientsock;  
-
-  d_clientsock=socket(AF_INET, SOCK_DGRAM,0);
-  if(d_clientsock<0) 
-    throw AhuException("Making a socket for resolver: "+stringerror());
-  setReceiveBuffer(d_clientsock, 200000);  
-  struct sockaddr_in sin;
-  memset((char *)&sin,0, sizeof(sin));
-  
-  sin.sin_family = AF_INET;
-
-  if(!IpToU32(::arg()["query-local-address"], &sin.sin_addr.s_addr))
-    throw AhuException("Unable to resolve local address '"+ ::arg()["query-local-address"] +"'"); 
-
-  int tries=10;
-  while(--tries) {
-    uint16_t port;
-    if(::arg()["query-local-port"].empty())
-      port=10000+Utility::random()%50000;
-    else {
-      port=::arg().asNum("query-local-port");
-      tries=1;
-    }
-    sin.sin_port = htons(port); 
-    
-    if (::bind(d_clientsock, (struct sockaddr *)&sin, sizeof(sin)) >= 0) 
-      break;
-    
-  }
-  if(!tries)
-    throw AhuException("Resolver binding to local query client socket: "+stringerror());
-
-  Utility::setNonBlocking(d_clientsock);
-  
-  //  L<<Logger::Error<<"Sending UDP queries from "<<inet_ntoa(sin.sin_addr)<<":"<< ntohs(sin.sin_port)  <<endl;
-}
 
 void makeTCPServerSockets()
 {
@@ -595,7 +640,7 @@ void makeUDPServerSockets()
       throw AhuException("Resolver binding to server socket for "+*i+": "+stringerror());
     
     Utility::setNonBlocking(fd);
-    d_udpserversocks.push_back(fd);
+    g_udpserversocks.push_back(fd);
     L<<Logger::Error<<"Listening for UDP queries on "<<inet_ntoa(sin.sin_addr)<<":"<<::arg().asNum("local-port")<<endl;
   }
 }
@@ -753,7 +798,7 @@ string questionExpand(const char* packet, uint16_t len)
   const char* pos=packet+12;
   unsigned char labellen;
   string ret;
-
+  ret.reserve(len-12);
   while((labellen=*pos++)) {
     if(pos+labellen > end)
       break;
@@ -804,6 +849,7 @@ int main(int argc, char **argv)
     ::arg().set("max-tcp-per-client", "If set, maximum number of TCP sessions per client (IP address)")="0";
     ::arg().set("fork", "If set, fork the daemon for possible double performance")="no";
     ::arg().set("spoof-nearmiss-max", "If non-zero, assume spoofing after this many near misses")="20";
+    ::arg().set("single-socket", "If set, only use a single socket for outgoing queries")="off";
 
     ::arg().setCmd("help","Provide a helpful message");
     L.toConsole(Logger::Warning);
@@ -871,8 +917,7 @@ int main(int argc, char **argv)
       fork();
       L<<Logger::Warning<<"This is forked pid "<<getpid()<<endl;
     }
-    d_clientsock=d_prevclientsock=-1;
-    remakeClientSocket();
+
 
     makeControlChannelSocket();
     
@@ -921,11 +966,13 @@ int main(int argc, char **argv)
 
     unsigned int maxTCPPerClient=::arg().asNum("max-tcp-per-client");
 
+    if(::arg().parmIsset("query-local-port") || ::arg().mustDo("single-socket"))
+      g_udpclientsocks.setPassthrough(true);
+
     for(;;) {
       while(MT->schedule()); // housekeeping, let threads do their thing
       
       if(!((counter++)%500)) {
-       remakeClientSocket();
        MT->makeThread(houseKeeping,0);
       }
       if(statsWanted) {
@@ -942,12 +989,13 @@ int main(int argc, char **argv)
       fd_set readfds, writefds;
       FD_ZERO( &readfds );
       FD_ZERO( &writefds );
-      FD_SET( d_clientsock, &readfds );
-      if(d_prevclientsock >= 0)
-       FD_SET( d_prevclientsock, &readfds );
 
       FD_SET( s_rcc.d_fd, &readfds);
-      int fdmax=max(d_clientsock, s_rcc.d_fd);
+      int fdmax=s_rcc.d_fd;
+      for(UDPClientSocks::socks_t::iterator i=g_udpclientsocks.d_socks.begin(); i!=g_udpclientsocks.d_socks.end(); ++i) {
+       FD_SET( i->first, &readfds );
+       fdmax=max(fdmax, i->first);
+      }
 
       if(!g_tcpconnections.empty())
        gettimeofday(&now, 0);
@@ -970,7 +1018,7 @@ int main(int argc, char **argv)
       }
       sweeped.swap(g_tcpconnections);
 
-      for(vector<int>::const_iterator i=d_udpserversocks.begin(); i!=d_udpserversocks.end(); ++i) {
+      for(vector<int>::const_iterator i=g_udpserversocks.begin(); i!=g_udpserversocks.end(); ++i) {
        FD_SET( *i, &readfds );
        fdmax=max(fdmax,*i);
       }
@@ -1010,20 +1058,20 @@ int main(int argc, char **argv)
        command();
       }
       
-      for(int port=0; port < 2; ++port) {
-       if(port && d_prevclientsock < 0)
-         break;
-       int sock = port ? d_prevclientsock : d_clientsock;
-         
-       if(FD_ISSET(sock,&readfds)) { // do we have a UDP question response from a server ("we are the client", hence d_clientsock)
-         while((d_len=recvfrom(sock, data, sizeof(data), 0, (sockaddr *)&fromaddr, &addrlen)) >= 0) {
-           dnsheader dh;
+      for(UDPClientSocks::socks_t::iterator i=g_udpclientsocks.d_socks.begin(); i!=g_udpclientsocks.d_socks.end(); i++ ) {
+       if(FD_ISSET(i->first, &readfds)) { // do we have a UDP question response from a server ("we are the client", hence clientsock)
+         while((d_len=recvfrom(i->first, data, sizeof(data), 0, (sockaddr *)&fromaddr, &addrlen)) >= 0) {
            if((size_t) d_len >= sizeof(dnsheader)) {
+             dnsheader dh;
              memcpy(&dh, data, sizeof(dh));
              
-             if(dh.qr && dh.qdcount) {
+             if(!dh.qdcount) // UPC, Nominum?
+               continue; 
+
+             if(dh.qr) {
                pident.remote=fromaddr;
                pident.id=dh.id;
+               pident.fd=i->first;
                pident.domain=questionExpand(data, d_len);
                string packet;
                packet.assign(data, d_len);
@@ -1039,7 +1087,7 @@ int main(int argc, char **argv)
                  }
                }
              }
-             else 
+             else
                L<<Logger::Warning<<"Ignoring question on outgoing socket from "<< sockAddrToString((struct sockaddr_in*) &fromaddr, addrlen)  <<endl;
            }
            else {
@@ -1048,10 +1096,11 @@ int main(int argc, char **argv)
                L<<Logger::Error<<"Unable to parse packet from remote UDP server "<< sockAddrToString((struct sockaddr_in*) &fromaddr, addrlen) <<": packet too small"<<endl;
            }
          }
+         g_udpclientsocks.returnSocket(i);
        }
       }
       
-      for(vector<int>::const_iterator i=d_udpserversocks.begin(); i!=d_udpserversocks.end(); ++i) {
+      for(vector<int>::const_iterator i=g_udpserversocks.begin(); i!=g_udpserversocks.end(); ++i) {
        if(FD_ISSET(*i,&readfds)) { // do we have a new question on udp?
          while((d_len=recvfrom(*i, data, sizeof(data), 0, (sockaddr *)&fromaddr, &addrlen)) >= 0) {
            //      g_stats.queryrate.pulse(now);  // (broken)
index 07ae33468c82627d2aa197ff9a4ee39f63852b38..becb1906e96cdf652a071a4d1e40398ab3732765 100644 (file)
@@ -317,7 +317,7 @@ int arecvtcp(string& data, int len, Socket* sock);
 
 struct PacketID
 {
-  PacketID() : sock(0), inNeeded(0), outPos(0), nearMisses(0)
+  PacketID() : sock(0), inNeeded(0), outPos(0), nearMisses(0), fd(-1)
   {}
 
   uint16_t id;  // wait for a specific id/remote pair
@@ -332,16 +332,17 @@ struct PacketID
   string::size_type outPos;    // how far we are along in the outMSG
 
   mutable uint32_t nearMisses; // number of near misses - host correct, id wrong
+  int fd;
 
   bool operator<(const PacketID& b) const
   {
     int ourSock= sock ? sock->getHandle() : 0;
     int bSock = b.sock ? b.sock->getHandle() : 0;
-    if( tie(id, remote.sin_addr.s_addr, remote.sin_port, ourSock) <
-        tie(b.id, b.remote.sin_addr.s_addr, b.remote.sin_port, bSock))
+    if( tie(id, remote.sin_addr.s_addr, remote.sin_port, fd, ourSock) <
+        tie(b.id, b.remote.sin_addr.s_addr, b.remote.sin_port, b.fd, bSock))
       return true;
-    if( tie(id, remote.sin_addr.s_addr, remote.sin_port, ourSock) >
-        tie(b.id, b.remote.sin_addr.s_addr, b.remote.sin_port, bSock))
+    if( tie(id, remote.sin_addr.s_addr, remote.sin_port, fd, ourSock) >
+        tie(b.id, b.remote.sin_addr.s_addr, b.remote.sin_port, b.fd, bSock))
       return false;
 
     return strcasecmp(domain.c_str(), b.domain.c_str()) < 0;