]> granicus.if.org Git - pdns/commitdiff
dnsdist: Try reading from the TCP backend right away
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 3 Apr 2019 16:10:55 +0000 (18:10 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 4 Apr 2019 09:54:06 +0000 (11:54 +0200)
Instead of waiting for the socket to be readable, as it might
already be, so we save a multiplexer trip, and prevent an issue
if we ever add a TLS layer between dnsdist and the backends.

pdns/dnsdist-tcp.cc

index 0e81173a7bc2f83a0a68572edf5701922bdd218f..fa341c208d9d4f0d3e7f6821fe7a94f3f04701ef 100644 (file)
@@ -531,8 +531,9 @@ public:
 static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
 static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
 static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
 
-static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state)
+static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
   handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
 
@@ -540,10 +541,7 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
     /* we need to resume reading from the backend! */
     state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
     state->d_currentPos = 0;
-    //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
-    // XXX: if we ever do TLS toward the backend, we need to try to read right away
-    // because the TLS layer might have more bits already waiting for us
-    handleNewIOState(state, IOState::NeedRead, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD());
+    handleDownstreamIO(state, now);
     return;
   }
 
@@ -552,8 +550,6 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
     return;
   }
 
-  struct timeval now;
-  gettimeofday(&now, 0);
   if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
     vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
     return;
@@ -564,7 +560,7 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
   handleIO(state, now);
 }
 
-static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
+static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
   state->d_state = IncomingTCPConnectionState::State::sendingResponse;
   const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
@@ -575,26 +571,10 @@ static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
 
   state->d_currentPos = 0;
 
-  try {
-    auto iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
-    if (iostate == IOState::Done) {
-
-      handleResponseSent(state);
-      return;
-    }
-    else {
-      //cerr<<__func__<<": adding client write FD "<<state->d_ci.fd<<endl;
-      handleNewIOState(state, IOState::NeedWrite, state->d_ci.fd, handleIOCallback, state->getClientWriteTTD());
-    }
-  }
-  catch (const std::exception& e) {
-    vinfolog("Got an exception while writing TCP response to %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
-    ++state->d_ci.cs->tcpDiedSendingResponse;
-    handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
-  }
+  handleIO(state, now);
 }
 
-static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
+static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
   if (state->d_responseSize < sizeof(dnsheader)) {
     return;
@@ -643,7 +623,7 @@ static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
     state->d_xfrStarted = true;
   }
 
-  sendResponse(state);
+  sendResponse(state, now);
 
   ++g_stats.responses;
   struct timespec answertime;
@@ -652,7 +632,7 @@ static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
   g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
 }
 
-static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state)
+static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
   auto ds = state->d_ds;
   state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
@@ -677,8 +657,7 @@ static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& stat
       return;
     }
 
-    //cerr<<__func__<<": add write backend FD "<<state->d_downstreamSocket->getHandle()<<endl;
-    handleNewIOState(state, IOState::NeedWrite, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendWriteTTD());
+    handleDownstreamIO(state, now);
     return;
   }
 
@@ -687,7 +666,7 @@ static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& stat
   vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
 }
 
-static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
+static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
   if (state->d_querySize < sizeof(dnsheader)) {
     ++g_stats.nonCompliantQueries;
@@ -702,9 +681,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
   /* we need an accurate ("real") value for the response and
      to store into the IDS, but not for insertion into the
      rings for example */
-  struct timespec now;
   struct timespec queryRealTime;
-  gettime(&now);
   gettime(&queryRealTime, true);
 
   auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
@@ -713,7 +690,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
   if (dnsCryptResponse) {
     state->d_responseBuffer = std::move(*dnsCryptResponse);
     state->d_responseSize = state->d_responseBuffer.size();
-    sendResponse(state);
+    sendResponse(state, now);
     return;
   }
 
@@ -744,7 +721,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
     state->d_buffer.resize(dq.len);
     state->d_responseBuffer = std::move(state->d_buffer);
     state->d_responseSize = state->d_responseBuffer.size();
-    sendResponse(state);
+    sendResponse(state, now);
     return;
   }
 
@@ -760,7 +737,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
      that could occur if we had to deal with the size during the processing,
      especially alignment issues */
   state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
-  sendQueryToBackend(state);
+  sendQueryToBackend(state, now);
 }
 
 static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
@@ -805,20 +782,15 @@ static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state,
   }
 }
 
-static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
-  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
   if (state->d_downstreamSocket == nullptr) {
     throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
   }
-  if (fd != state->d_downstreamSocket->getHandle()) {
-    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle()));
-  }
 
+  int fd = state->d_downstreamSocket->getHandle();
   IOState iostate = IOState::Done;
   bool connectionDied = false;
-  struct timeval now;
-  gettimeofday(&now, 0);
 
   try {
     if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
@@ -877,7 +849,7 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param
         }
         fd = -1;
 
-        handleResponse(state);
+        handleResponse(state, now);
         return;
       }
     }
@@ -922,8 +894,23 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param
   }
 
   if (connectionDied) {
-    sendQueryToBackend(state);
+    sendQueryToBackend(state, now);
+  }
+}
+
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (state->d_downstreamSocket == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+  if (fd != state->d_downstreamSocket->getHandle()) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle()));
   }
+
+  struct timeval now;
+  gettimeofday(&now, 0);
+  handleDownstreamIO(state, now);
 }
 
 static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
@@ -967,7 +954,7 @@ static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct
       iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
       if (iostate == IOState::Done) {
         handleNewIOState(state, IOState::Done, fd, handleIOCallback);
-        handleQuery(state);
+        handleQuery(state, now);
         return;
       }
     }
@@ -975,7 +962,7 @@ static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct
     if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
       iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
       if (iostate == IOState::Done) {
-        handleResponseSent(state);
+        handleResponseSent(state, now);
         return;
       }
     }