]> granicus.if.org Git - pdns/commitdiff
dnsdist: Check the flags to detect collisions in the packet cache
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 21 Jun 2018 10:38:50 +0000 (12:38 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Jun 2018 09:09:13 +0000 (11:09 +0200)
In the unlikely but quite real event two queries with the same qname,
qtype and qclass but different EDNS options or flags end up with
the same hash, the packet cache would return a answer that might
not be suitable for the query. Reduce the odds by checking the
flags present in the dns header in addition to the qname, qtype
and qclass.
For the same reason we might need to consider storing the ECS
subnet if any.

pdns/dnsdist-cache.cc
pdns/dnsdist-cache.hh
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/test-dnsdistpacketcache_cc.cc

index 656e32f07408f60ba435870ddef02f63220bce73..be09c311c7da75758bb4a4fd4e0a704a36e6a912 100644 (file)
@@ -47,14 +47,16 @@ DNSDistPacketCache::~DNSDistPacketCache()
   }
 }
 
-bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
+bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
 {
-  if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname)
+  if (cachedValue.queryFlags != queryFlags || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) {
     return false;
+  }
+
   return true;
 }
 
-void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity)
+void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity)
 {
   auto& map = shard.d_map;
   /* check again now that we hold the lock to prevent a race */
@@ -76,7 +78,7 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNS
   CacheValue& value = it->second;
   bool wasExpired = value.validity <= now;
 
-  if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) {
+  if (!wasExpired && !cachedValueMatches(value, queryFlags, qname, qtype, qclass, tcp)) {
     d_insertCollisions++;
     return;
   }
@@ -89,10 +91,11 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNS
   value = newValue;
 }
 
-void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
+void DNSDistPacketCache::insert(uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
 {
-  if (responseLen < sizeof(dnsheader))
+  if (responseLen < sizeof(dnsheader)) {
     return;
+  }
 
   uint32_t minTTL;
 
@@ -136,6 +139,7 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty
   newValue.qname = qname;
   newValue.qtype = qtype;
   newValue.qclass = qclass;
+  newValue.queryFlags = queryFlags;
   newValue.len = responseLen;
   newValue.validity = newValidity;
   newValue.added = now;
@@ -151,12 +155,12 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty
       d_deferredInserts++;
       return;
     }
-    insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity)    ;
+    insertLocked(shard, key, queryFlags, qname, qtype, qclass, tcp, newValue, now, newValidity)    ;
   }
   else {
     WriteLock w(&shard.d_lock);
 
-    insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity)    ;
+    insertLocked(shard, key, queryFlags, qname, qtype, qclass, tcp, newValue, now, newValidity)    ;
   }
 }
 
@@ -202,7 +206,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t
     }
 
     /* check for collision */
-    if (!cachedValueMatches(value, *dq.qname, dq.qtype, dq.qclass, dq.tcp)) {
+    if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp)) {
       d_lookupCollisions++;
       return false;
     }
index 6aae4fc0db7acba993624792222de6ffb7769b19..c0892ded650e076d2056c121f067f28c5b3277e1 100644 (file)
@@ -33,7 +33,7 @@ public:
   DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t tempFailureTTL=60, uint32_t maxNegativeTTL=3600, uint32_t staleTTL=60, bool dontAge=false, uint32_t shards=1, bool deferrableInsertLock=true);
   ~DNSDistPacketCache();
 
-  void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL);
+  void insert(uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL);
   bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, uint32_t allowExpired=0, bool skipAging=false);
   void purgeExpired(size_t upTo=0);
   void expunge(size_t upTo=0);
@@ -62,6 +62,7 @@ private:
     DNSName qname;
     uint16_t qtype{0};
     uint16_t qclass{0};
+    uint16_t queryFlags{0};
     time_t added{0};
     time_t validity{0};
     uint16_t len{0};
@@ -91,9 +92,9 @@ private:
   };
 
   static uint32_t getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp);
-  static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp);
+  static bool cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp);
   uint32_t getShardIndex(uint32_t key) const;
-  void insertLocked(CacheShard& shard, uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity);
+  void insertLocked(CacheShard& shard, uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity);
 
   std::vector<CacheShard> d_shards;
 
index ab2bda916c06247f67da007c8eaa893dd1a15393..5268a3be8898ee18129d5789fb966a76f64c974e 100644 (file)
@@ -600,7 +600,7 @@ void* tcpClientThread(int pipefd)
         }
 
        if (packetCache && !dq.skipCache) {
-         packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
+         packetCache->insert(cacheKey, origFlags, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
        }
 
 #ifdef HAVE_DNSCRYPT
index c03e5aff90c9d0b0e48a019e670eb72e52edc776..9e8ad0055562f45fa44c25449ee4edb6f48958bf 100644 (file)
@@ -490,7 +490,7 @@ try {
         }
 
         if (ids->packetCache && !ids->skipCache) {
-          ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
+          ids->packetCache->insert(ids->cacheKey, ids->origFlags, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
         }
 
         if (ids->cs && !ids->cs->muted) {
index 0d46082bb36e1af622e3cf34345e28f737e3df06..a2ec1a7ae4457392ed9e718373fb7f041554f8b4 100644 (file)
@@ -3,6 +3,8 @@
 
 #include <boost/test/unit_test.hpp>
 
+#include "ednssubnet.hh"
+#include "ednsoptions.hh"
 #include "dnsdist.hh"
 #include "iputils.hh"
 #include "dnswriter.hh"
@@ -44,11 +46,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       char responseBuf[4096];
       uint16_t responseBufSize = sizeof(responseBuf);
       uint32_t key = 0;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+      auto dh = reinterpret_cast<dnsheader*>(query.data());
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
       bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
       BOOST_CHECK_EQUAL(found, false);
 
-      PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none);
+      PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none);
 
       found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true);
       if (found == true) {
@@ -140,17 +143,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) {
     char responseBuf[4096];
     uint16_t responseBufSize = sizeof(responseBuf);
     uint32_t key = 0;
-    DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+    auto dh = reinterpret_cast<dnsheader*>(query.data());
+    DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
     BOOST_CHECK_EQUAL(found, false);
 
     // Insert with failure-TTL of 0 (-> should not enter cache).
-    PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional<uint32_t>(0));
+    PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional<uint32_t>(0));
     found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true);
     BOOST_CHECK_EQUAL(found, false);
 
     // Insert with failure-TTL non-zero (-> should enter cache).
-    PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional<uint32_t>(300));
+    PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional<uint32_t>(300));
     found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true);
     BOOST_CHECK_EQUAL(found, true);
   }
@@ -192,11 +196,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) {
     char responseBuf[4096];
     uint16_t responseBufSize = sizeof(responseBuf);
     uint32_t key = 0;
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+    auto dh = reinterpret_cast<dnsheader*>(query.data());
+    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key);
     BOOST_CHECK_EQUAL(found, false);
 
-    PC.insert(key, name, QType::A, QClass::IN, reinterpret_cast<const char*>(response.data()), responseLen, false, RCode::NoError, boost::none);
+    PC.insert(key, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast<const char*>(response.data()), responseLen, false, RCode::NoError, boost::none);
     found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true);
     BOOST_CHECK_EQUAL(found, true);
     sleep(2);
@@ -243,11 +248,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) {
     char responseBuf[4096];
     uint16_t responseBufSize = sizeof(responseBuf);
     uint32_t key = 0;
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+    auto dh = reinterpret_cast<dnsheader*>(query.data());
+    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key);
     BOOST_CHECK_EQUAL(found, false);
 
-    PC.insert(key, name, QType::A, QClass::IN, reinterpret_cast<const char*>(response.data()), responseLen, false, RCode::NXDomain, boost::none);
+    PC.insert(key, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast<const char*>(response.data()), responseLen, false, RCode::NXDomain, boost::none);
     found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true);
     BOOST_CHECK_EQUAL(found, true);
     sleep(2);
@@ -291,10 +297,11 @@ static void *threadMangler(void* off)
       char responseBuf[4096];
       uint16_t responseBufSize = sizeof(responseBuf);
       uint32_t key = 0;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+      auto dh = reinterpret_cast<dnsheader*>(query.data());
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
       PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
 
-      PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none);
+      PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none);
     }
   }
   catch(PDNSException& e) {