]> granicus.if.org Git - pdns/commitdiff
auth: Add TCP management options described in section 10 of rfc7766
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 19 Sep 2016 15:09:35 +0000 (17:09 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 20 Dec 2016 08:47:37 +0000 (09:47 +0100)
* max-tcp-connection-duration
* max-tcp-connections-per-client
* max-tcp-transactions-per-conn
* tcp-idle-timeout

docs/markdown/authoritative/settings.md
pdns/common_startup.cc
pdns/tcpreceiver.cc
pdns/tcpreceiver.hh

index f62906a85b35234290f5aee90737029015d602f1..0daefa5fb46c10aaffd1e1293f9e5015a6a73a4d 100644 (file)
@@ -496,12 +496,33 @@ hopeless and respawn.
 
 Maximum number of signatures cache entries
 
+## `max-tcp-connection-duration`
+* Integer
+* Default: 0
+
+Maximum time in seconds that a TCP DNS connection is allowed to stay open.
+0 means unlimited.
+Note that exchanges related to an AXFR or IXFR are not affected by this setting.
+
 ## `max-tcp-connections`
 * Integer
 * Default: 20
 
 Allow this many incoming TCP DNS connections simultaneously.
 
+## `max-tcp-connections-per-client`
+* Integer
+* Default: 0
+
+Maximum number of simultaneous TCP connections per client. 0 means unlimited.
+
+## `max-tcp-transactions-per-conn`
+* Integer
+* Default: 0
+
+Allow this many DNS queries in a single TCP transaction. 0 means unlimited.
+Note that exchanges related to an AXFR or IXFR are not affected by this setting.
+
 ## `module-dir`
 * Path
 
@@ -759,6 +780,13 @@ Limit TCP control to a specific client range.
 
 Password for TCP control.
 
+## `tcp-idle-timeout`
+* Integer
+* Default: 5
+
+Maximum time in seconds that a TCP DNS connection is allowed to stay open
+while being idle, meaning without PowerDNS receiving or sending even a single byte.
+
 ## `traceback-handler`
 * Boolean
 * Default: yes
index dd7729523d6574b99dd5235ed4234198bda5cb58..4972f217a856698b345e9a523961e0bb607111e1 100644 (file)
@@ -162,6 +162,11 @@ void declareArguments()
 
   ::arg().set("default-ttl","Seconds a result is valid if not set otherwise")="3600";
   ::arg().set("max-tcp-connections","Maximum number of TCP connections")="20";
+  ::arg().set("max-tcp-connections-per-client","Maximum number of simultaneous TCP connections per client")="0";
+  ::arg().set("max-tcp-transactions-per-conn")="0";
+  ::arg().set("max-tcp-connection-duration")="0";
+  ::arg().set("tcp-idle-timeout")="5";
+
   ::arg().setSwitch("no-shuffle","Set this to prevent random shuffling of answers - for regression testing")="off";
 
   ::arg().set("setuid","If set, change user id to this uid for more security")="";
index 1aa6cc647cc07a9d1d769695a0b6540d5d4b54f1..e29de4aaffaaf906ed9bd34db3f2ae66d44058fd 100644 (file)
@@ -68,6 +68,12 @@ pthread_mutex_t TCPNameserver::s_plock = PTHREAD_MUTEX_INITIALIZER;
 Semaphore *TCPNameserver::d_connectionroom_sem;
 PacketHandler *TCPNameserver::s_P; 
 NetmaskGroup TCPNameserver::d_ng;
+size_t TCPNameserver::d_maxTransactionsPerConn;
+size_t TCPNameserver::d_maxConnectionsPerClient;
+unsigned int TCPNameserver::d_idleTimeout;
+unsigned int TCPNameserver::d_maxConnectionDuration;
+std::mutex TCPNameserver::s_clientsCountMutex;
+std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> TCPNameserver::s_clientsCount;
 
 void TCPNameserver::go()
 {
@@ -89,16 +95,21 @@ void *TCPNameserver::launcher(void *data)
 }
 
 // throws PDNSException if things didn't go according to plan, returns 0 if really 0 bytes were read
-int readnWithTimeout(int fd, void* buffer, unsigned int n, bool throwOnEOF=true)
+static int readnWithTimeout(int fd, void* buffer, unsigned int n, unsigned int idleTimeout, bool throwOnEOF=true, unsigned int totalTimeout=0)
 {
   unsigned int bytes=n;
   char *ptr = (char*)buffer;
   int ret;
+  time_t start = 0;
+  unsigned int remainingTotal = totalTimeout;
+  if (totalTimeout) {
+    start = time(NULL);
+  }
   while(bytes) {
     ret=read(fd, ptr, bytes);
     if(ret < 0) {
       if(errno==EAGAIN) {
-        ret=waitForData(fd, 5);
+        ret=waitForData(fd, (totalTimeout == 0 || idleTimeout <= remainingTotal) ? idleTimeout : remainingTotal);
         if(ret < 0)
           throw NetworkError("Waiting for data read");
         if(!ret)
@@ -117,12 +128,21 @@ int readnWithTimeout(int fd, void* buffer, unsigned int n, bool throwOnEOF=true)
     
     ptr += ret;
     bytes -= ret;
+    if (totalTimeout) {
+      time_t now = time(NULL);
+      unsigned int elapsed = now - start;
+      if (elapsed >= remainingTotal) {
+        throw NetworkError("Timeout while reading data");
+      }
+      start = now;
+      remainingTotal -= elapsed;
+    }
   }
   return n;
 }
 
 // ditto
-void writenWithTimeout(int fd, const void *buffer, unsigned int n)
+static void writenWithTimeout(int fd, const void *buffer, unsigned int n, unsigned int idleTimeout)
 {
   unsigned int bytes=n;
   const char *ptr = (char*)buffer;
@@ -131,7 +151,7 @@ void writenWithTimeout(int fd, const void *buffer, unsigned int n)
     ret=write(fd, ptr, bytes);
     if(ret < 0) {
       if(errno==EAGAIN) {
-        ret=waitForRWData(fd, false, 5, 0);
+        ret=waitForRWData(fd, false, idleTimeout, 0);
         if(ret < 0)
           throw NetworkError("Waiting for data write");
         if(!ret)
@@ -184,20 +204,20 @@ void TCPNameserver::sendPacket(shared_ptr<DNSPacket> p, int outsock)
   uint16_t len=htons(p->getString().length());
   string buffer((const char*)&len, 2);
   buffer.append(p->getString());
-  writenWithTimeout(outsock, buffer.c_str(), buffer.length());
+  writenWithTimeout(outsock, buffer.c_str(), buffer.length(), d_idleTimeout);
 }
 
 
-void TCPNameserver::getQuestion(int fd, char *mesg, int pktlen, const ComboAddress &remote)
+void TCPNameserver::getQuestion(int fd, char *mesg, int pktlen, const ComboAddress &remote, unsigned int totalTime)
 try
 {
-  readnWithTimeout(fd, mesg, pktlen);
+  readnWithTimeout(fd, mesg, pktlen, d_idleTimeout, true, totalTime);
 }
 catch(NetworkError& ae) {
   throw NetworkError("Error reading DNS data from TCP client "+remote.toString()+": "+ae.what());
 }
 
-static void proxyQuestion(shared_ptr<DNSPacket> packet)
+static void proxyQuestion(shared_ptr<DNSPacket> packet, unsigned int idleTimeout)
 {
   int sock=socket(AF_INET, SOCK_STREAM, 0);
   
@@ -217,19 +237,19 @@ static void proxyQuestion(shared_ptr<DNSPacket> packet)
     
     uint16_t len=htons(buffer.length()), slen;
     
-    writenWithTimeout(sock, &len, 2);
-    writenWithTimeout(sock, buffer.c_str(), buffer.length());
+    writenWithTimeout(sock, &len, 2, idleTimeout);
+    writenWithTimeout(sock, buffer.c_str(), buffer.length(), idleTimeout);
     
-    readnWithTimeout(sock, &len, 2);
+    readnWithTimeout(sock, &len, 2, idleTimeout);
     len=ntohs(len);
 
     char answer[len];
-    readnWithTimeout(sock, answer, len);
+    readnWithTimeout(sock, answer, len, idleTimeout);
 
     slen=htons(len);
-    writenWithTimeout(packet->getSocket(), &slen, 2);
+    writenWithTimeout(packet->getSocket(), &slen, 2, idleTimeout);
     
-    writenWithTimeout(packet->getSocket(), answer, len);
+    writenWithTimeout(packet->getSocket(), answer, len, idleTimeout);
   }
   catch(NetworkError& ae) {
     close(sock);
@@ -248,6 +268,30 @@ static void incTCPAnswerCount(const ComboAddress& remote)
   else
     S.inc("tcp4-answers");
 }
+
+static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+{
+  if (maxConnectionDuration) {
+    time_t elapsed = time(NULL) - start;
+    if (elapsed >= maxConnectionDuration) {
+      return true;
+    }
+    remainingTime = maxConnectionDuration - elapsed;
+  }
+  return false;
+}
+
+void TCPNameserver::decrementClientCount(const ComboAddress& remote)
+{
+  if (d_maxConnectionsPerClient) {
+    std::lock_guard<std::mutex> lock(s_clientsCountMutex);
+    s_clientsCount[remote]--;
+    if (s_clientsCount[remote] == 0) {
+      s_clientsCount.erase(remote);
+    }
+  }
+}
+
 void *TCPNameserver::doConnection(void *data)
 {
   shared_ptr<DNSPacket> packet;
@@ -255,6 +299,11 @@ void *TCPNameserver::doConnection(void *data)
   int fd=(int)(long)data; // gotta love C (generates a harmless warning on opteron)
   ComboAddress remote;
   socklen_t remotelen=sizeof(remote);
+  size_t transactions = 0;
+  time_t start = 0;
+  if (d_maxConnectionDuration) {
+    start = time(NULL);
+  }
 
   pthread_detach(pthread_self());
   if(getpeername(fd, (struct sockaddr *)&remote, &remotelen) < 0) {
@@ -277,9 +326,19 @@ void *TCPNameserver::doConnection(void *data)
     DLOG(L<<"TCP Connection accepted on fd "<<fd<<endl);
     bool logDNSQueries= ::arg().mustDo("log-dns-queries");
     for(;;) {
+      unsigned int remainingTime = 0;
+      transactions++;
+      if (d_maxTransactionsPerConn && transactions > d_maxTransactionsPerConn) {
+        L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the number of transactions per connection, dropping.";
+        break;
+      }
+      if (maxConnectionDurationReached(d_maxConnectionDuration, start, remainingTime)) {
+        L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the maximum TCP connection duration, dropping.";
+        break;
+      }
 
       uint16_t pktlen;
-      if(!readnWithTimeout(fd, &pktlen, 2, false))
+      if(!readnWithTimeout(fd, &pktlen, 2, d_idleTimeout, false, remainingTime))
         break;
       else
         pktlen=ntohs(pktlen);
@@ -296,7 +355,12 @@ void *TCPNameserver::doConnection(void *data)
         break;
       }
       
-      getQuestion(fd, mesg.get(), pktlen, remote);
+      if (maxConnectionDurationReached(d_maxConnectionDuration, start, remainingTime)) {
+        L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the maximum TCP connection duration, dropping.";
+        break;
+      }
+
+      getQuestion(fd, mesg.get(), pktlen, remote, remainingTime);
       S.inc("tcp-queries");      
       if(remote.sin4.sin_family == AF_INET6)
         S.inc("tcp6-queries");
@@ -363,7 +427,7 @@ void *TCPNameserver::doConnection(void *data)
         if(LPE) LPE->police(&(*packet), &(*reply), true);
 
         if(shouldRecurse) {
-          proxyQuestion(packet);
+          proxyQuestion(packet, d_idleTimeout);
           continue;
         }
       }
@@ -399,6 +463,7 @@ void *TCPNameserver::doConnection(void *data)
   catch(const PDNSException& e) {
     L<<Logger::Error<<"Error closing TCP socket: "<<e.reason<<endl;
   }
+  decrementClientCount(remote);
 
   return 0;
 }
@@ -1182,6 +1247,11 @@ TCPNameserver::~TCPNameserver()
 
 TCPNameserver::TCPNameserver()
 {
+  d_maxTransactionsPerConn = ::arg().asNum("max-tcp-transactions-per-conn");
+  d_idleTimeout = ::arg().asNum("tcp-idle-timeout");
+  d_maxConnectionDuration = ::arg().asNum("max-tcp-connection-duration");
+  d_maxConnectionsPerClient = ::arg().asNum("max-tcp-connections-per-client");
+
 //  sem_init(&d_connectionroom_sem,0,::arg().asNum("max-tcp-connections"));
   d_connectionroom_sem = new Semaphore( ::arg().asNum( "max-tcp-connections" ));
   d_tid=0;
@@ -1290,8 +1360,8 @@ void TCPNameserver::thread()
   try {
     for(;;) {
       int fd;
-      struct sockaddr_in remote;
-      Utility::socklen_t addrlen=sizeof(remote);
+      ComboAddress remote;
+      Utility::socklen_t addrlen=remote.getSocklen();
 
       int ret=poll(&d_prfds[0], d_prfds.size(), -1); // blocks, forever if need be
       if(ret <= 0)
@@ -1301,7 +1371,8 @@ void TCPNameserver::thread()
       for(const pollfd& pfd :  d_prfds) {
         if(pfd.revents == POLLIN) {
           sock = pfd.fd;
-          addrlen=sizeof(remote);
+          remote.sin4.sin_family = AF_INET6;
+          addrlen=remote.getSocklen();
 
           if((fd=accept(sock, (sockaddr*)&remote, &addrlen))<0) {
             L<<Logger::Error<<"TCP question accept error: "<<strerror(errno)<<endl;
@@ -1312,6 +1383,16 @@ void TCPNameserver::thread()
             }
           }
           else {
+            if (d_maxConnectionsPerClient) {
+              std::lock_guard<std::mutex> lock(s_clientsCountMutex);
+              if (s_clientsCount[remote] >= d_maxConnectionsPerClient) {
+                L<<Logger::Notice<<"Limit of simultaneous TCP connections per client reached for "<< remote<<", dropping"<<endl;
+                close(fd);
+                continue;
+              }
+              s_clientsCount[remote]++;
+            }
+
             pthread_t tid;
             d_connectionroom_sem->wait(); // blocks if no connections are available
 
@@ -1324,6 +1405,7 @@ void TCPNameserver::thread()
               L<<Logger::Error<<"Error creating thread: "<<stringerror()<<endl;
               d_connectionroom_sem->post();
               close(fd);
+              decrementClientCount(remote);
             }
           }
         }
index 0e7d3492855aba9a84a74e651fe703d9bbd37e0a..cfd6d3001678f40ec58f04e1a32c22ec1d3c599b 100644 (file)
@@ -27,7 +27,7 @@
 #include "dnsbackend.hh"
 #include "packethandler.hh"
 #include <vector>
-
+#include <mutex>
 #include <poll.h>
 #include <sys/select.h>
 #include <sys/socket.h>
@@ -51,18 +51,25 @@ private:
 
   static void sendPacket(std::shared_ptr<DNSPacket> p, int outsock);
   static int readLength(int fd, ComboAddress *remote);
-  static void getQuestion(int fd, char *mesg, int pktlen, const ComboAddress& remote);
+  static void getQuestion(int fd, char *mesg, int pktlen, const ComboAddress& remote, unsigned int totalTime);
   static int doAXFR(const DNSName &target, std::shared_ptr<DNSPacket> q, int outsock);
   static int doIXFR(std::shared_ptr<DNSPacket> q, int outsock);
   static bool canDoAXFR(std::shared_ptr<DNSPacket> q);
   static void *doConnection(void *data);
   static void *launcher(void *data);
+  static void decrementClientCount(const ComboAddress& remote);
   void thread(void);
   static pthread_mutex_t s_plock;
+  static std::mutex s_clientsCountMutex;
+  static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> s_clientsCount;
   static PacketHandler *s_P;
   pthread_t d_tid;
   static Semaphore *d_connectionroom_sem;
   static NetmaskGroup d_ng;
+  static size_t d_maxTransactionsPerConn;
+  static size_t d_maxConnectionsPerClient;
+  static unsigned int d_idleTimeout;
+  static unsigned int d_maxConnectionDuration;
 
   vector<int>d_sockets;
   vector<struct pollfd> d_prfds;