]> granicus.if.org Git - pdns/commitdiff
dnsdist: Refactoring of the response handling path
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 26 Feb 2019 13:17:46 +0000 (14:17 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 4 Apr 2019 09:46:25 +0000 (11:46 +0200)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh

index f13422943e0382e8e08b92cb5618597808a2f9b7..26608a4356c027109b647d6569c0a6ee49423367 100644 (file)
@@ -245,7 +245,7 @@ static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, tim
   return false;
 }
 
-void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
+static void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
 {
   for(auto it = sockets.begin(); it != sockets.end(); ) {
     if (isTCPSocketUsable(it->second)) {
@@ -272,11 +272,9 @@ void tcpClientThread(int pipefd)
   
   LocalHolders holders;
   auto localRespRulactions = g_resprulactions.getLocal();
-#ifdef HAVE_DNSCRYPT
   /* when the answer is encrypted in place, we need to get a copy
      of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
-#endif
+  dnsheader cleartextDH;
 
   map<ComboAddress,int> sockets;
   for(;;) {
@@ -297,7 +295,7 @@ void tcpClientThread(int pipefd)
     vector<uint8_t> rewrittenResponse;
     shared_ptr<DownstreamState> ds;
     size_t queriesCount = 0;
-    time_t connectionStartTime = time(NULL);
+    time_t connectionStartTime = time(nullptr);
     std::vector<char> queryBuffer;
     std::vector<char> answerBuffer;
 
@@ -343,7 +341,7 @@ void tcpClientThread(int pipefd)
 
         /* allocate a bit more memory to be able to spoof the content,
            or to add ECS without allocating a new buffer */
-        queryBuffer.resize(qlen + 512);
+        queryBuffer.resize((static_cast<size_t>(qlen) + 512) < 4096 ? (static_cast<size_t>(qlen) + 512) : 4096);
 
         char* query = &queryBuffer[0];
         handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
@@ -357,7 +355,6 @@ void tcpClientThread(int pipefd)
         gettime(&queryRealTime, true);
 
         std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
 #ifdef HAVE_DNSCRYPT
         auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true);
         if (dnsCryptResponse) {
@@ -377,22 +374,21 @@ void tcpClientThread(int pipefd)
         DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
         dq.dnsCryptQuery = std::move(dnsCryptQuery);
 
-        responseSender sender = [&handler](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) {
-          handler.writeSizeAndMsg(data, dataSize, g_tcpSendTimeout);
-        };
+        std::shared_ptr<DownstreamState> ds{nullptr};
+        auto result = processQuery(dq, *ci.cs, holders, ds);
 
-        bool dropped = false;
-        auto ds = processQuery(dq, *ci.cs, holders, sender, dropped);
-        if (!ds) {
-          if (dropped) {
-            break;
-          }
+        if (result == ProcessQueryResult::Drop) {
+          break;
+        }
+
+        if (result == ProcessQueryResult::SendAnswer) {
+          handler.writeSizeAndMsg(reinterpret_cast<char*>(dq.dh), dq.len, g_tcpSendTimeout);
           continue;
         }
 
-        // check how that would work!!
-        char cachedResponse[4096];
-        uint16_t cachedResponseSize = sizeof cachedResponse;
+        if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
+          break;
+        }
 
        int dsock = -1;
        uint16_t downstreamFailures=0;
@@ -478,7 +474,7 @@ void tcpClientThread(int pipefd)
         size_t responseSize = rlen;
         uint16_t addRoom = 0;
 #ifdef HAVE_DNSCRYPT
-        if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
+        if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
         }
 #endif
@@ -502,43 +498,36 @@ void tcpClientThread(int pipefd)
           break;
         }
         firstPacket=false;
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, dq.origFlags, dq.ednsAdded, dq.ecsAdded, rewrittenResponse, addRoom, dq.useZeroScope ? &zeroScope : nullptr)) {
-          break;
-        }
 
-        dh = (struct dnsheader*) response;
+        dh = reinterpret_cast<struct dnsheader*>(response);
         DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
+        dr.origFlags = dq.origFlags;
+        dr.ecsAdded = dq.ecsAdded;
+        dr.ednsAdded = dq.ednsAdded;
+        dr.useZeroScope = dq.useZeroScope;
+        dr.packetCache = std::move(dq.packetCache);
+        dr.delayMsec = dq.delayMsec;
+        dr.skipCache = dq.skipCache;
+        dr.cacheKey = dq.cacheKey;
+        dr.cacheKeyNoECS = dq.cacheKeyNoECS;
+        dr.dnssecOK = dq.dnssecOK;
+        dr.tempFailureTTL = dq.tempFailureTTL;
+        dr.qTag = std::move(dq.qTag);
+        dr.subnet = std::move(dq.subnet);
 #ifdef HAVE_PROTOBUF
-        dr.uniqueId = dq.uniqueId;
+        dr.uniqueId = std::move(dq.uniqueId);
 #endif
-        dr.qTag = dq.qTag;
-
-        if (!processResponse(localRespRulactions, dr, &dq.delayMsec)) {
-          break;
+#ifdef HAVE_DNSCRYPT
+        if (dq.dnsCryptQuery) {
+          dr.dnsCryptQuery = std::move(dq.dnsCryptQuery);
         }
+#endif
 
-       if (dq.packetCache && !dq.skipCache) {
-          if (!dq.useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          dq.packetCache->insert(zeroScope ? dq.cacheKeyNoECS : dq.cacheKey, zeroScope ? boost::none : dq.subnet, dq.origFlags, dq.dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
-       }
-
-#ifdef HAVE_DNSCRYPT
-        if (!encryptResponse(response, &responseLen, responseSize, true, dq.dnsCryptQuery, &dh, &dhCopy)) {
+        memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
+        if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
           break;
         }
-#endif
+
         if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
           break;
         }
@@ -576,7 +565,7 @@ void tcpClientThread(int pipefd)
         struct timespec answertime;
         gettime(&answertime);
         unsigned int udiff = 1000000.0*DiffTime(now,answertime);
-        g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote);
+        g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(responseLen), cleartextDH, ds->remote);
 
         rewrittenResponse.clear();
       }
@@ -625,9 +614,9 @@ void tcpAcceptorThread(void* p)
       ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo);
       ci->cs = cs;
 #ifdef HAVE_ACCEPT4
-      ci->fd = accept4(cs->tcpFD, (struct sockaddr*)&remote, &remlen, SOCK_NONBLOCK);
+      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
 #else
-      ci->fd = accept(cs->tcpFD, (struct sockaddr*)&remote, &remlen);
+      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
 #endif
       if(ci->fd < 0) {
         throw std::runtime_error((boost::format("accepting new connection on socket: %s") % strerror(errno)).str());
@@ -685,7 +674,7 @@ void tcpAcceptorThread(void* p)
         }
       }
     }
-    catch(std::exception& e) {
+    catch(const std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
       if(tcpClientCountIncremented) {
         decrementTCPClientCount(remote);
index 5d6d8850d31f6dd5cba05c8777279b65883bcd15..07f587e923a8d48a110f3bcfb1d4c97a28effb97 100644 (file)
@@ -191,7 +191,7 @@ struct DelayedPacket
   }
 };
 
-DelayPipe<DelayedPacket> * g_delay = 0;
+DelayPipe<DelayedPacket>* g_delay = nullptr;
 
 void doLatencyStats(double udiff)
 {
@@ -214,14 +214,11 @@ void doLatencyStats(double udiff)
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed)
 {
-  uint16_t rqtype, rqclass;
-  DNSName rqname;
-  const struct dnsheader* dh = (struct dnsheader*) response;
-
   if (responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response);
   if (dh->qdcount == 0) {
     if ((dh->rcode != RCode::NoError && dh->rcode != RCode::NXDomain) || g_allowEmptyResponse) {
       return true;
@@ -232,12 +229,15 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
     }
   }
 
+  uint16_t rqtype, rqclass;
+  DNSName rqname;
   try {
     rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
   }
-  catch(std::exception& e) {
-    if(responseLen > (ssize_t)sizeof(dnsheader))
+  catch(const std::exception& e) {
+    if(responseLen > 0 && static_cast<size_t>(responseLen) > sizeof(dnsheader)) {
       infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
+    }
     ++g_stats.nonCompliantResponses;
     return false;
   }
@@ -270,14 +270,13 @@ static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
   return addEDNSToQueryTurnedResponse(dq);
 }
 
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
+static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
 {
-  struct dnsheader* dh = (struct dnsheader*) *response;
-
   if (*responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(*response);
   restoreFlags(dh, origFlags);
 
   if (*responseLen == sizeof(dnsheader)) {
@@ -368,7 +367,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
 }
 
 #ifdef HAVE_DNSCRYPT
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
+static bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
 {
   if (dnsCryptQuery) {
     uint16_t encryptedResponseLen = 0;
@@ -392,7 +391,80 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize,
 }
 #endif
 
-static bool sendUDPResponse(int origFD, const char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+static bool applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr)
+{
+  DNSResponseAction::Action action=DNSResponseAction::Action::None;
+  std::string ruleresult;
+  for(const auto& lr : *localRespRulactions) {
+    if(lr.d_rule->matches(&dr)) {
+      lr.d_rule->d_matches++;
+      action=(*lr.d_action)(&dr, &ruleresult);
+      switch(action) {
+      case DNSResponseAction::Action::Allow:
+        return true;
+        break;
+      case DNSResponseAction::Action::Drop:
+        return false;
+        break;
+      case DNSResponseAction::Action::HeaderModify:
+        return true;
+        break;
+      case DNSResponseAction::Action::ServFail:
+        dr.dh->rcode = RCode::ServFail;
+        return true;
+        break;
+        /* non-terminal actions follow */
+      case DNSResponseAction::Action::Delay:
+        dr.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        break;
+      case DNSResponseAction::Action::None:
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted)
+{
+  if (!applyRulesToResponse(localRespRulactions, dr)) {
+    return false;
+  }
+
+  bool zeroScope = false;
+  if (!fixUpResponse(response, responseLen, responseSize, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, rewrittenResponse, addRoom, dr.useZeroScope ? &zeroScope : nullptr)) {
+    return false;
+  }
+
+  if (dr.packetCache && !dr.skipCache) {
+    if (!dr.useZeroScope) {
+      /* if the query was not suitable for zero-scope, for
+         example because it had an existing ECS entry so the hash is
+         not really 'no ECS', so just insert it for the existing subnet
+         since:
+         - we don't have the correct hash for a non-ECS query
+         - inserting with hash computed before the ECS replacement but with
+         the subnet extracted _after_ the replacement would not work.
+      */
+      zeroScope = false;
+    }
+    // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
+    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, *response, *responseLen, dr.tcp, dr.dh->rcode, dr.tempFailureTTL);
+  }
+
+#ifdef HAVE_DNSCRYPT
+  if (!muted) {
+    if (!encryptResponse(*response, responseLen, *responseSize, dr.tcp, dr.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
+    }
+  }
+#endif
+
+  return true;
+}
+
+static bool sendUDPResponse(int origFD, const char* response, const uint16_t responseLen, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   if(delayMsec && g_delay) {
     DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
@@ -401,7 +473,7 @@ static bool sendUDPResponse(int origFD, const char* response, uint16_t responseL
   else {
     ssize_t res;
     if(origDest.sin4.sin_family == 0) {
-      res = sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen());
+      res = sendto(origFD, response, responseLen, 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
     }
     else {
       res = sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
@@ -443,13 +515,13 @@ try {
   auto localRespRulactions = g_resprulactions.getLocal();
 #ifdef HAVE_DNSCRYPT
   char packet[4096 + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE];
-  /* when the answer is encrypted in place, we need to get a copy
-     of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
 #else
   char packet[4096];
 #endif
   static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
+  /* when the answer is encrypted in place, we need to get a copy
+     of the original header before encryption to fill the ring buffer */
+  dnsheader cleartextDH;
   vector<uint8_t> rewrittenResponse;
 
   uint16_t queryId = 0;
@@ -465,10 +537,10 @@ try {
         char * response = packet;
         size_t responseSize = sizeof(packet);
 
-        if (got < (ssize_t) sizeof(dnsheader))
+        if (got < static_cast<ssize_t>(sizeof(dnsheader)))
           continue;
 
-        uint16_t responseLen = (uint16_t) got;
+        uint16_t responseLen = static_cast<uint16_t>(got);
         queryId = dh->id;
 
         if(queryId >= dss->idStates.size())
@@ -509,63 +581,50 @@ try {
 
         uint16_t addRoom = 0;
         DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
+        dr.origFlags = ids->origFlags;
+        dr.ecsAdded = ids->ecsAdded;
+        dr.ednsAdded = ids->ednsAdded;
+        dr.useZeroScope = ids->useZeroScope;
+        dr.packetCache = std::move(ids->packetCache);
+        dr.delayMsec = ids->delayMsec;
+        dr.skipCache = ids->skipCache;
+        dr.cacheKey = ids->cacheKey;
+        dr.cacheKeyNoECS = ids->cacheKeyNoECS;
+        dr.dnssecOK = ids->dnssecOK;
+        dr.tempFailureTTL = ids->tempFailureTTL;
+        dr.qTag = std::move(ids->qTag);
+        dr.subnet = std::move(ids->subnet);
 #ifdef HAVE_PROTOBUF
-        dr.uniqueId = ids->uniqueId;
+        dr.uniqueId = std::move(ids->uniqueId);
 #endif
-        dr.qTag = ids->qTag;
-
-        if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
-          continue;
-        }
-
 #ifdef HAVE_DNSCRYPT
         if (ids->dnsCryptQuery) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+          dr.dnsCryptQuery = std::move(ids->dnsCryptQuery);
         }
 #endif
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, ids->ecsAdded, rewrittenResponse, addRoom, ids->useZeroScope ? &zeroScope : nullptr)) {
-          continue;
-        }
 
-        if (ids->packetCache && !ids->skipCache) {
-          if (!ids->useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          ids->packetCache->insert(zeroScope ? ids->cacheKeyNoECS : ids->cacheKey, zeroScope ? boost::none : ids->subnet, ids->origFlags, ids->dnssecOK, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
+        memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
+        if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, ids->cs && ids->cs->muted)) {
+          continue;
         }
 
         if (ids->cs && !ids->cs->muted) {
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery, &dh, &dhCopy)) {
-            continue;
-          }
-#endif
-
           ComboAddress empty;
           empty.sin4.sin_family = 0;
           /* if ids->destHarvested is false, origDest holds the listening address.
              We don't want to use that as a source since it could be 0.0.0.0 for example. */
-          sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
+          sendUDPResponse(origFD, response, responseLen, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
         }
 
         ++g_stats.responses;
 
         double udiff = ids->sentTime.udiff();
-        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
+        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), dr.remote->toStringWithPort(), udiff);
 
         struct timespec ts;
         gettime(&ts);
-        g_rings.insertResponse(ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, dss->remote);
+        g_rings.insertResponse(ts, *dr.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(got), cleartextDH, dss->remote);
 
         switch (dh->rcode) {
         case RCode::NXDomain:
@@ -583,14 +642,6 @@ try {
 
         doLatencyStats(udiff);
 
-        /* if the FD is not -1, the state has been actively reused and we should
-           not alter anything */
-        if (ids->origFD == -1) {
-#ifdef HAVE_DNSCRYPT
-          ids->dnsCryptQuery = nullptr;
-#endif
-        }
-
         rewrittenResponse.clear();
       }
     }
@@ -1172,41 +1223,6 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
   return true;
 }
 
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec)
-{
-  DNSResponseAction::Action action=DNSResponseAction::Action::None;
-  std::string ruleresult;
-  for(const auto& lr : *localRespRulactions) {
-    if(lr.d_rule->matches(&dr)) {
-      lr.d_rule->d_matches++;
-      action=(*lr.d_action)(&dr, &ruleresult);
-      switch(action) {
-      case DNSResponseAction::Action::Allow:
-        return true;
-        break;
-      case DNSResponseAction::Action::Drop:
-        return false;
-        break;
-      case DNSResponseAction::Action::HeaderModify:
-        return true;
-        break;
-      case DNSResponseAction::Action::ServFail:
-        dr.dh->rcode = RCode::ServFail;
-        return true;
-        break;
-        /* non-terminal actions follow */
-      case DNSResponseAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
-        break;
-      case DNSResponseAction::Action::None:
-        break;
-      }
-    }
-  }
-
-  return true;
-}
-
 static ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
 {
   ssize_t result;
@@ -1328,32 +1344,36 @@ static void queueResponse(const ClientState& cs, const char* response, uint16_t
 }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
 
-static int sendResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, char* response, uint16_t responseLen, bool cacheHit, responseSender sender)
+/* self-generated responses or cache hits */
+static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
 {
-  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, dq.tcp, dq.queryTime);
+  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(dq.dh), dq.size, dq.len, dq.tcp, dq.queryTime);
 
 #ifdef HAVE_PROTOBUF
   dr.uniqueId = dq.uniqueId;
 #endif
   dr.qTag = dq.qTag;
+  dr.delayMsec = dq.delayMsec;
 
-  if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &dq.delayMsec)) {
-    return -1;
+  if (!applyRulesToResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr)) {
+    return false;
   }
 
+  /* in case a rule changed it */
+  dq.delayMsec = dr.delayMsec;
+
   if (!cs.muted) {
 #ifdef HAVE_DNSCRYPT
-    if (!encryptResponse(response, &responseLen, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
-      return -1;
+    if (!encryptResponse(reinterpret_cast<char*>(dq.dh), &dq.len, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
     }
 #endif
-
-    sender(cs, response, responseLen, dq.delayMsec, *dq.local, *dq.remote);
   }
 
   if (cacheHit) {
     ++g_stats.cacheHits;
   }
+
   switch (dr.dh->rcode) {
   case RCode::NXDomain:
     ++g_stats.frontendNXDomain;
@@ -1365,12 +1385,12 @@ static int sendResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq,
     ++g_stats.frontendNoError;
     break;
   }
+
   doLatencyStats(0);  // we're not going to measure this
-  return 0;
+  return true;
 }
 
-/* returns nullptr if the query has been taken care of (cache-hit, self-answered or discarded) and a backend it should be sent to otherwise */
-std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, responseSender sender, bool& dropped)
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
   const uint16_t queryId = ntohs(dq.dh->id);
 
@@ -1384,26 +1404,20 @@ std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs,
     string poolname;
 
     if (!applyRulesToQuery(holders, dq, poolname, now)) {
-      dropped = true;
-      return nullptr;
+      return ProcessQueryResult::Drop;
     }
 
     if(dq.dh->qr) { // something turned it into a response
       fixUpQueryTurnedResponse(dq, dq.origFlags);
 
-      if (!cs.muted) {
-        char* response = reinterpret_cast<char*>(dq.dh);
-        uint16_t responseLen = dq.len;
-
-        sendResponse(holders, cs, dq, response, responseLen, false, sender);
-
-        ++g_stats.selfAnswered;
+      if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+        return ProcessQueryResult::Drop;
       }
 
-      return nullptr;
+      ++g_stats.selfAnswered;
+      return ProcessQueryResult::SendAnswer;
     }
 
-    std::shared_ptr<DownstreamState> ss{nullptr};
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
     dq.packetCache = serverPool->packetCache;
     auto policy = *(holders.policy);
@@ -1413,25 +1427,30 @@ std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs,
     auto servers = serverPool->getServers();
     if (policy.isLua) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      ss = policy.policy(servers, &dq);
+      selectedBackend = policy.policy(servers, &dq);
     }
     else {
-      ss = policy.policy(servers, &dq);
+      selectedBackend = policy.policy(servers, &dq);
     }
 
     uint16_t cachedResponseSize = dq.size;
-    uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
+    uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
 
     if (dq.packetCache && !dq.skipCache) {
       dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
     }
 
-    if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
+    if (dq.useECS && ((selectedBackend && selectedBackend->useECS) || (!selectedBackend && serverPool->getECS()))) {
       // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-      if (dq.packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
+      if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
         if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
-          sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
-          return nullptr;
+          dq.len = cachedResponseSize;
+
+          if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+            return ProcessQueryResult::Drop;
+          }
+
+          return ProcessQueryResult::SendAnswer;
         }
 
         if (!dq.subnet) {
@@ -1442,49 +1461,54 @@ std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs,
 
       if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
         vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
-        dropped = true;
-        return nullptr;
+        return ProcessQueryResult::Drop;
       }
     }
 
     if (dq.packetCache && !dq.skipCache) {
       if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
-        sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
-        return nullptr;
+        dq.len = cachedResponseSize;
+
+        if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+          return ProcessQueryResult::Drop;
+        }
+
+        return ProcessQueryResult::SendAnswer;
       }
       ++g_stats.cacheMisses;
     }
 
-    if(!ss) {
+    if(!selectedBackend) {
       ++g_stats.noPolicy;
 
-      if (g_servFailOnNoPolicy && !cs.muted) {
-        char* response = reinterpret_cast<char*>(dq.dh);
-        uint16_t responseLen = dq.len;
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      if (g_servFailOnNoPolicy) {
         restoreFlags(dq.dh, dq.origFlags);
 
         dq.dh->rcode = RCode::ServFail;
         dq.dh->qr = true;
 
-        sendResponse(holders, cs, dq, response, responseLen, false, sender);
+        if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+          return ProcessQueryResult::Drop;
+        }
         // no response-only statistics counter to update.
+        return ProcessQueryResult::SendAnswer;
       }
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
-      return nullptr;
+
+      return ProcessQueryResult::Drop;
     }
 
-    if (dq.addXPF && ss->xpfRRCode != 0) {
-      addXPF(dq, ss->xpfRRCode, g_preserveTrailingData);
+    if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
+      addXPF(dq, selectedBackend->xpfRRCode, g_preserveTrailingData);
     }
 
-    ss->queries++;
-    return ss;
+    selectedBackend->queries++;
+    return ProcessQueryResult::PassToBackend;
   }
   catch(const std::exception& e){
     vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
-    dropped = true;
   }
-  return nullptr;
+  return ProcessQueryResult::Drop;
 }
 
 static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
@@ -1501,12 +1525,9 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
        to store into the IDS, but not for insertion into the
        rings for example */
     struct timespec queryRealTime;
-    struct timespec now;
-    gettime(&now);
     gettime(&queryRealTime, true);
 
     std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
 #ifdef HAVE_DNSCRYPT
     auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
     if (dnsCryptResponse) {
@@ -1527,22 +1548,26 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
     DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
     dq.dnsCryptQuery = std::move(dnsCryptQuery);
+    std::shared_ptr<DownstreamState> ss{nullptr};
+    auto result = processQuery(dq, cs, holders, ss);
 
-    responseSender sender = [&responsesVect, &queuedResponses, &respIOV, &respCBuf](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) -> void {
+    if (result == ProcessQueryResult::Drop) {
+      return;
+    }
+
+    if (result == ProcessQueryResult::SendAnswer) {
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-      if (delayMsec == 0 && responsesVect != nullptr) {
-        queueResponse(cs, data, dataSize, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+      if (dq.delayMsec == 0 && responsesVect != nullptr) {
+        queueResponse(cs, reinterpret_cast<char*>(dq.dh), dq.len, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
         (*queuedResponses)++;
         return;
       }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-      sendUDPResponse(cs.udpFD, data, dataSize, delayMsec, dest, remote);
-    };
-
-    bool dropped = false;
-    auto ss = processQuery(dq, cs, holders, sender, dropped);
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dq.dh), dq.len, dq.delayMsec, *dq.local, *dq.remote);
+      return;
+    }
 
-    if (!ss) {
+    if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
       return;
     }
 
@@ -1676,8 +1701,8 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
       unsigned int got = msgVec[msgIdx].msg_len;
       const ComboAddress& remote = recvData[msgIdx].remote;
 
-      if (got < sizeof(struct dnsheader)) {
-        g_stats.nonCompliantQueries++;
+      if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
+        ++g_stats.nonCompliantQueries;
         continue;
       }
 
@@ -1803,7 +1828,7 @@ try
     sock.bind(ds->sourceAddr);
   }
   sock.connect(ds->remote);
-  ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
+  ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), reinterpret_cast<char*>(&packet[0]), packet.size(), true);
   if (sent < 0) {
     int ret = errno;
     if (g_verboseHealthChecks)
index 137f250bc69b0de769e1834424d283cdd6aabfd1..513cd95955ee6b3f3b4552da2211e61012ccd900 100644 (file)
@@ -1046,15 +1046,13 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed);
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec);
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope);
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted);
 
 bool checkQueryHeaders(const struct dnsheader* dh);
 
 #ifdef HAVE_DNSCRYPT
 extern std::vector<std::tuple<ComboAddress, std::shared_ptr<DNSCryptContext>, bool, int, std::string, std::set<int> > > g_dnsCryptLocals;
 
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
 int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response);
 
 boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp);
@@ -1073,6 +1071,6 @@ extern bool g_addEDNSToSelfGeneratedResponses;
 
 static const size_t s_udpIncomingBufferSize{1500};
 
-typedef std::function<void(const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote)> responseSender;
-std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, responseSender sender, bool& dropped);
+enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend };
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);