]> granicus.if.org Git - pdns/commitdiff
Add support for TCP timeouts and limit the number of retries on downstream server
authorRemi Gacogne <rgacogne-github@coredump.fr>
Thu, 5 Nov 2015 23:47:14 +0000 (00:47 +0100)
committerRemi Gacogne <rgacogne-github@coredump.fr>
Fri, 6 Nov 2015 11:04:55 +0000 (12:04 +0100)
The TCP timeouts default to 30s in both directions (send and receive) for downstream servers, and to 2s for clients.

We now only retry 5 times if the downstream server keeps failing on us.

pdns/README-dnsdist.md
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/misc.cc
pdns/misc.hh

index 368eed15e1623b55be81ed8f80f07b64625dd9a1..5c2fb257da6bdab9a36cab973ea82293c5838845 100644 (file)
@@ -133,6 +133,24 @@ To change the QPS for a server:
 > getServer(0):setQPS(1000)
 ```
 
+TCP timeouts
+------------
+
+By default, a 2-second timeout is enforced on the TCP connection from the client,
+meaning that a connection will be closed if the query can't be read in less than 2s
+or if the answer can't be sent in less than 2s. This can be configured with:
+```
+> setTCPRecvTimeout(5)
+> setTCPSendTimeout(5)
+```
+
+The same kind of timeout is enforced on the TCP connections to the downstream servers.
+The default value of 30s can be modified by passing the `tcpRecvTimeout` and `tcpSendTimeout`
+parameters to `newServer`:
+```
+newServer {address="192.0.2.1", tcpRecvTimeout=10, tcpSendTimeout=10}
+```
+
 Webserver
 ---------
 To visually interact with dnsdist, try adding:
@@ -544,7 +562,7 @@ Here are all functions:
    * `errlog(string)`: log at level error
  * Server related:
    * `newServer("ip:port")`: instantiate a new downstream server with default settings
-   * `newServer({address="ip:port", qps=1000, order=1, weight=10, pool="abuse"})`: instantiate
+   * `newServer({address="ip:port", qps=1000, order=1, weight=10, pool="abuse", retries=5, tcpSendTimeout=30, tcpRecvTimeout=30})`: instantiate
      a server with additional parameters
    * `showServers()`: output all servers
    * `getServer(n)`: returns server with index n 
index 3afa7923b89e2a5833067ef8d748263f059104c6..0ef97a5f7d37039b4283d534cc2e4184c1a78a2e 100644 (file)
@@ -152,6 +152,18 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                          ret->weight=boost::lexical_cast<int>(boost::get<string>(vars["weight"]));
                        }
 
+                       if(vars.count("retries")) {
+                         ret->retries=boost::lexical_cast<int>(boost::get<string>(vars["retries"]));
+                       }
+
+                       if(vars.count("tcpSendTimeout")) {
+                         ret->tcpSendTimeout=boost::lexical_cast<int>(boost::get<string>(vars["tcpSendTimeout"]));
+                       }
+
+                       if(vars.count("tcpRecvTimeout")) {
+                         ret->tcpRecvTimeout=boost::lexical_cast<int>(boost::get<string>(vars["tcpRecvTimeout"]));
+                       }
+
                        if(g_launchWork) {
                          g_launchWork->push_back([ret]() {
                              ret->tid = move(thread(responderThread, ret));
@@ -814,6 +826,9 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
        g_outputBuffer="Crypto failed..\n";
      }});
 
+  g_lua.writeFunction("setTCPRecvTimeout", [](int timeout) { g_tcpRecvTimeout=timeout; });
+
+  g_lua.writeFunction("setTCPSendTimeout", [](int timeout) { g_tcpSendTimeout=timeout; });
   
   std::ifstream ifs(config);
   if(!ifs) 
index e42d9daa7de06cce5b5bae3fe4b5eb6849eda62e..45d12a681b119e08b15f3a027cc81c42d88e77c2 100644 (file)
@@ -48,6 +48,7 @@ static int setupTCPDownstream(const ComboAddress& remote)
   vinfolog("TCP connecting to downstream %s", remote.toStringWithPort());
   int sock = SSocket(remote.sin4.sin_family, SOCK_STREAM, 0);
   SConnect(sock, remote);
+  setNonBlocking(sock);
   return sock;
 }
 
@@ -68,12 +69,40 @@ void TCPClientCollection::addTCPClientThread()
   if(pipe(pipefds) < 0)
     unixDie("Creating pipe");
 
+  if (!setNonBlocking(pipefds[1]))
+    unixDie("Setting pipe non-blocking");
+
   d_tcpclientthreads.push_back(pipefds[1]);    
   thread t1(tcpClientThread, pipefds[0]);
   t1.detach();
   ++d_numthreads;
 }
 
+/* Reads the two-byte length prefix of a message from `fd`, forwarding
+   `timeout` to readn2WithTimeout so the read cannot wait forever.
+   On success the value, converted from network to host byte order, is
+   stored in *len and true is returned. Every failure, including any
+   exception thrown by readn2WithTimeout, is reported as false. */
+static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
+{
+  try {
+    uint16_t wireLen;
+    const int got = readn2WithTimeout(fd, &wireLen, sizeof wireLen, timeout);
+    if(got == sizeof wireLen) {
+      *len = ntohs(wireLen);
+      return true;
+    }
+    return false;
+  }
+  catch(...) {
+    /* callers only need to know that the length could not be read */
+    return false;
+  }
+}
+
+/* Sends `len` as a two-byte, network-order length prefix on `fd`,
+   forwarding `timeout` to writen2WithTimeout. Returns true only when
+   both bytes were written; any exception thrown by writen2WithTimeout
+   is swallowed and reported as false. */
+static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
+{
+  try {
+    const uint16_t wireLen = htons(len);
+    const int sent = writen2WithTimeout(fd, &wireLen, sizeof wireLen, timeout);
+    if(sent == sizeof wireLen) {
+      return true;
+    }
+    return false;
+  }
+  catch(...) {
+    /* callers only need to know that the prefix could not be sent */
+    return false;
+  }
+}
+
 TCPClientCollection g_tcpclientthreads;
 
 void* tcpClientThread(int pipefd)
@@ -107,13 +136,16 @@ void* tcpClientThread(int pipefd)
     string pool; 
 
     shared_ptr<DownstreamState> ds;
+    if (!setNonBlocking(ci.fd))
+      goto drop;
+
     try {
       for(;;) {      
-        if(!getMsgLen(ci.fd, &qlen))
+        if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
           break;
         
         char query[qlen];
-        readn2(ci.fd, query, qlen);
+        readn2WithTimeout(ci.fd, query, qlen, g_tcpRecvTimeout);
        uint16_t qtype;
        DNSName qname(query, qlen, 12, false, &qtype);
        string ruleresult;
@@ -168,8 +200,8 @@ void* tcpClientThread(int pipefd)
        }
        
        if(dh->qr) { // something turned it into a response
-         putMsgLen(ci.fd, qlen);
-         writen2(ci.fd, query, rlen);
+         if (putNonBlockingMsgLen(ci.fd, qlen, g_tcpSendTimeout))
+           writen2WithTimeout(ci.fd, query, rlen, g_tcpSendTimeout);
          goto drop;
        }
 
@@ -194,28 +226,43 @@ void* tcpClientThread(int pipefd)
        if(qtype == QType::AXFR || qtype == QType::IXFR)  // XXX fixme we really need to do better
          break;
 
+        uint16_t downstream_failures=0;
       retry:; 
-        if(!putMsgLen(dsock, qlen)) {
+        if (dsock < 0) {
+          sockets.erase(ds->remote);
+          break;
+        }
+
+        if (ds->retries > 0 && downstream_failures > ds->retries) {
+          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->remote.toStringWithPort(), downstream_failures);
+          close(dsock);
+          sockets.erase(ds->remote);
+          break;
+        }
+
+        if(!putNonBlockingMsgLen(dsock, qlen, ds->tcpSendTimeout)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->remote.toStringWithPort());
           close(dsock);
           sockets[ds->remote]=dsock=setupTCPDownstream(ds->remote);
+          downstream_failures++;
           goto retry;
         }
       
-        writen2(dsock, query, qlen);
+        writen2WithTimeout(dsock, query, qlen, ds->tcpSendTimeout);
       
-        if(!getMsgLen(dsock, &rlen)) {
+        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->remote.toStringWithPort());
           close(dsock);
           sockets[ds->remote]=dsock=setupTCPDownstream(ds->remote);
+          downstream_failures++;
           goto retry;
         }
 
         char answerbuffer[rlen];
-        readn2(dsock, answerbuffer, rlen);
+        readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
       
-        putMsgLen(ci.fd, rlen);
-        writen2(ci.fd, answerbuffer, rlen);
+        if (putNonBlockingMsgLen(ci.fd, rlen, ds->tcpSendTimeout))
+          writen2WithTimeout(ci.fd, answerbuffer, rlen, ds->tcpSendTimeout);
       }
     }
     catch(...){}
@@ -265,7 +312,8 @@ void* tcpAcceptorThread(void* p)
       vinfolog("Got TCP connection from %s", remote.toStringWithPort());
       
       ci->remote = remote;
-      writen2(g_tcpclientthreads.getThread(), &ci, sizeof(ci));
+      int pipe = g_tcpclientthreads.getThread();
+      writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
     }
     catch(std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
index a2f922cf12acfd3f72edc05a82c2e85ad7827e73..4fb2e6cec64787278e5a7d0f15dd832e1e1faf97 100644 (file)
@@ -103,6 +103,9 @@ Rings g_rings;
 
 GlobalStateHolder<servers_t> g_dstates;
 
+int g_tcpRecvTimeout{2};
+int g_tcpSendTimeout{2};
+
 bool g_truncateTC{1};
 void truncateTC(const char* packet, unsigned int* len)
 try
index a5c598aacad8378486627820860da4ff1b12aac4..7d68354b3dbe3bf869e075666821ccd9a0bbb447 100644 (file)
@@ -246,6 +246,9 @@ struct DownstreamState
   double latencyUsec{0.0};
   int order{1};
   int weight{1};
+  int tcpRecvTimeout{30};
+  int tcpSendTimeout{30};
+  uint16_t retries{5};
   StopWatch sw;
   set<string> pools;
   enum class Availability { Up, Down, Auto} availability{Availability::Auto};
@@ -323,6 +326,8 @@ extern ComboAddress g_serverControl; // not changed during runtime
 extern std::vector<std::pair<ComboAddress, bool>> g_locals; // not changed at runtime (we hope XXX)
 extern std::string g_key; // in theory needs locking
 extern bool g_truncateTC;
+extern int g_tcpRecvTimeout;
+extern int g_tcpSendTimeout;
 struct dnsheader;
 
 void controlThread(int fd, ComboAddress local);
index bf5dae296daeb0bda60b788003b9376178d5ea5b..a3befbc4e5c8f46246daf1521c7c594a71a69eb7 100644 (file)
@@ -87,12 +87,12 @@ int readn2(int fd, void* buffer, unsigned int len)
   for(;;) {
     res = read(fd, (char*)buffer + pos, len - pos);
     if(res == 0)
-      throw runtime_error("EOF while writing message");
+      throw runtime_error("EOF while reading message");
     if(res < 0) {
       if (errno == EAGAIN)
-        throw std::runtime_error("used writen2 on non-blocking socket, got EAGAIN");
+        throw std::runtime_error("used readn2 on non-blocking socket, got EAGAIN");
       else
-        unixDie("failed in writen2");
+        unixDie("failed in readn2");
     }
 
     pos+=res;
@@ -102,6 +102,71 @@ int readn2(int fd, void* buffer, unsigned int len)
   return len;
 }
 
+/* Reads exactly `len` bytes from the non-blocking descriptor `fd` into
+   `buffer`, and returns `len`.
+
+   Whenever read() reports EAGAIN, the function waits via
+   waitForData(fd, timeout) before trying again; the timeout therefore
+   applies to each individual wait, not to the call as a whole.
+
+   Throws runtime_error when read() returns 0 (EOF), when the wait
+   returns 0 (timeout) or when the wait returns a negative value.
+   Any other read() failure is handed to unixDie(). */
+int readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
+{
+  char* const dest = (char *)buffer;
+  size_t done = 0;
+
+  /* at least one read() is always attempted, even when len is 0 */
+  for (;;) {
+    const ssize_t got = read(fd, dest + done, len - done);
+
+    if (got == 0) {
+      throw runtime_error("EOF while reading message");
+    }
+
+    if (got > 0) {
+      done += (size_t) got;
+    }
+    else if (errno != EAGAIN) {
+      unixDie("failed in readn2WithTimeout");
+    }
+    else {
+      /* nothing to read yet: wait until data shows up or we time out */
+      const int ready = waitForData(fd, timeout);
+      if (ready == 0) {
+        throw runtime_error("Timeout while waiting for data to read");
+      }
+      if (ready < 0) {
+        throw runtime_error("Error while waiting for data to read");
+      }
+    }
+
+    if (done >= len) {
+      break;
+    }
+  }
+
+  return len;
+}
+
+/* Writes exactly `len` bytes from `buffer` to the non-blocking descriptor
+   `fd`, and returns `len`. A zero-length request returns immediately
+   without touching the descriptor.
+
+   Whenever write() reports EAGAIN, the function waits via
+   waitForRWData(fd, false, timeout, 0) before trying again; the timeout
+   therefore applies to each individual wait, not to the call as a whole.
+
+   Throws runtime_error when write() returns 0, when the wait returns 0
+   (timeout) or when the wait returns a negative value. Any other write()
+   failure is handed to unixDie(). */
+int writen2WithTimeout(int fd, const void * buffer, size_t len, int timeout)
+{
+  /* keep the const qualifier of the caller's buffer */
+  const char* data = (const char *)buffer;
+  size_t pos = 0;
+
+  /* test before writing: a zero-byte write() returns 0 and would
+     otherwise be misreported as EOF */
+  while (pos < len) {
+    ssize_t written = write(fd, data + pos, len - pos);
+
+    if (written > 0) {
+      pos += (size_t) written;
+    }
+    else if (written == 0)
+      throw runtime_error("EOF while writing message");
+    else {
+      if (errno == EAGAIN) {
+        /* no room yet: wait until the descriptor is writable or we time out */
+        int res = waitForRWData(fd, false, timeout, 0);
+        if (res > 0) {
+          /* there is room available */
+        }
+        else if (res == 0) {
+          throw runtime_error("Timeout while waiting to write data");
+        } else {
+          throw runtime_error("Error while waiting for room to write data");
+        }
+      }
+      else {
+        unixDie("failed in writen2WithTimeout");
+      }
+    }
+  }
+
+  return len;
+}
 
 string nowTime()
 {
index 2141cd143cb18b9492eb600e434f47886f964774..19b5e38754a0cabe72115eb83e057b60d24e1e3a 100644 (file)
@@ -151,6 +151,8 @@ vstringtok (Container &container, string const &in,
 int writen2(int fd, const void *buf, size_t count);
 inline int writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); }
 int readn2(int fd, void* buffer, unsigned int len);
+int readn2WithTimeout(int fd, void* buffer, size_t len, int timeout);
+int writen2WithTimeout(int fd, const void * buffer, size_t len, int timeout);
 
 const string toLower(const string &upper);
 const string toLowerCanonic(const string &upper);