From 8dcdbdb11e711db9140afe3f3728f0a22b550101 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Thu, 21 Jun 2018 12:38:50 +0200 Subject: [PATCH] dnsdist: Check the flags to detect collisions in the packet cache In the unlikely but quite real event two queries with the same qname, qtype and qclass but different EDNS options or flags end up with the same hash, the packet cache would return a answer that might not be suitable for the query. Reduce the odds by checking the flags present in the dns header in addition to the qname, qtype and qclass. For the same reason we might need to consider storing the ECS subnet if any. --- pdns/dnsdist-cache.cc | 22 +++++++++++++--------- pdns/dnsdist-cache.hh | 7 ++++--- pdns/dnsdist-tcp.cc | 2 +- pdns/dnsdist.cc | 2 +- pdns/test-dnsdistpacketcache_cc.cc | 29 ++++++++++++++++++----------- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 656e32f07..be09c311c 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -47,14 +47,16 @@ DNSDistPacketCache::~DNSDistPacketCache() } } -bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp) +bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp) { - if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) + if (cachedValue.queryFlags != queryFlags || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) { return false; + } + return true; } -void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity) +void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity) { auto& map = shard.d_map; /* check again now that we hold the lock to prevent a race */ @@ -76,7 +78,7 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNS CacheValue& value = it->second; bool wasExpired = value.validity <= now; - if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) { + if (!wasExpired && !cachedValueMatches(value, queryFlags, qname, qtype, qclass, tcp)) { d_insertCollisions++; return; } @@ -89,10 +91,11 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNS value = newValue; } -void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) +void DNSDistPacketCache::insert(uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) { - if (responseLen < sizeof(dnsheader)) + if (responseLen < sizeof(dnsheader)) { return; + } uint32_t minTTL; @@ -136,6 +139,7 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty newValue.qname = qname; newValue.qtype = qtype; newValue.qclass = qclass; + newValue.queryFlags = queryFlags; newValue.len = responseLen; newValue.validity = newValidity; newValue.added = now; @@ -151,12 +155,12 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty d_deferredInserts++; return; } - insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity) ; + insertLocked(shard, key, queryFlags, qname, qtype, qclass, tcp, newValue, now, newValidity) ; } else { WriteLock w(&shard.d_lock); - insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity) ; + insertLocked(shard, key, queryFlags, qname, qtype, qclass, tcp, newValue, now, newValidity) ; } } @@ -202,7 +206,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t } /* check for collision */ - if (!cachedValueMatches(value, *dq.qname, dq.qtype, dq.qclass, dq.tcp)) { + if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp)) { d_lookupCollisions++; return false; } diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh index 6aae4fc0d..c0892ded6 100644 --- a/pdns/dnsdist-cache.hh +++ b/pdns/dnsdist-cache.hh @@ -33,7 +33,7 @@ public: DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t tempFailureTTL=60, uint32_t maxNegativeTTL=3600, uint32_t staleTTL=60, bool dontAge=false, uint32_t shards=1, bool deferrableInsertLock=true); ~DNSDistPacketCache(); - void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); + void insert(uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, uint32_t allowExpired=0, bool skipAging=false); void purgeExpired(size_t upTo=0); void expunge(size_t upTo=0); @@ -62,6 +62,7 @@ private: DNSName qname; uint16_t qtype{0}; uint16_t qclass{0}; + uint16_t queryFlags{0}; time_t added{0}; time_t validity{0}; uint16_t len{0}; @@ -91,9 +92,9 @@ private: }; static uint32_t getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp); - static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp); + static bool cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp); uint32_t getShardIndex(uint32_t key) const; - void insertLocked(CacheShard& shard, uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity); + void insertLocked(CacheShard& shard, uint32_t key, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity); std::vector d_shards; diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index ab2bda916..5268a3be8 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -600,7 +600,7 @@ void* tcpClientThread(int pipefd) } if (packetCache && !dq.skipCache) { - packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL); + packetCache->insert(cacheKey, origFlags, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL); } #ifdef HAVE_DNSCRYPT diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index c03e5aff9..9e8ad0055 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -490,7 +490,7 @@ try { } if (ids->packetCache && !ids->skipCache) { - ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL); + ids->packetCache->insert(ids->cacheKey, ids->origFlags, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL); } if (ids->cs && !ids->cs->muted) { diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 0d46082bb..a2ec1a7ae 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -3,6 +3,8 @@ #include +#include "ednssubnet.hh" +#include "ednsoptions.hh" #include "dnsdist.hh" #include "iputils.hh" #include "dnswriter.hh" @@ -44,11 +46,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); + auto dh = reinterpret_cast(query.data()); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key); BOOST_CHECK_EQUAL(found, false); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); if (found == true) { @@ -140,17 +143,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); + auto dh = reinterpret_cast(query.data()); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key); BOOST_CHECK_EQUAL(found, false); // Insert with failure-TTL of 0 (-> should not enter cache). - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(0)); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(0)); found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); BOOST_CHECK_EQUAL(found, false); // Insert with failure-TTL non-zero (-> should enter cache). - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(300)); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(300)); found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); BOOST_CHECK_EQUAL(found, true); } @@ -192,11 +196,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); + auto dh = reinterpret_cast(query.data()); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key); BOOST_CHECK_EQUAL(found, false); - PC.insert(key, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NoError, boost::none); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NoError, boost::none); found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); BOOST_CHECK_EQUAL(found, true); sleep(2); @@ -243,11 +248,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); + auto dh = reinterpret_cast(query.data()); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key); BOOST_CHECK_EQUAL(found, false); - PC.insert(key, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NXDomain, boost::none); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NXDomain, boost::none); found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); BOOST_CHECK_EQUAL(found, true); sleep(2); @@ -291,10 +297,11 @@ static void *threadMangler(void* off) char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); + auto dh = reinterpret_cast(query.data()); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + PC.insert(key, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); } } catch(PDNSException& e) { -- 2.40.0