From a176d205922081b97ed404c6615ef43e807202e3 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Fri, 19 Feb 2016 12:58:05 +0100 Subject: [PATCH] dnsdist: Do not share the packet cache entries between TCP and UDP It would obviously cause issues, for example with truncated responses. It is possible to disable the cache for all TCP queries by using something like: addAction(TCPRule(true), SkipCacheAction()) --- pdns/dnsdist-cache.cc | 18 +++--- pdns/dnsdist-cache.hh | 9 +-- pdns/dnsdist-tcp.cc | 4 +- pdns/dnsdist.cc | 4 +- pdns/test-dnsdistpacketcache_cc.cc | 16 ++--- regression-tests.dnsdist/test_Advanced.py | 74 ++++++----------------- 6 files changed, 45 insertions(+), 80 deletions(-) diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 092e02c26..29a1c587e 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -15,14 +15,14 @@ DNSDistPacketCache::~DNSDistPacketCache() WriteLock l(&d_lock); } -bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass) +bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp) { - if (cachedValue.qname != qname || cachedValue.qtype != qtype || cachedValue.qclass != qclass) + if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) return false; return true; } -void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen) +void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp) { if (responseLen == 0) return; @@ -56,6 +56,7 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty newValue.len = responseLen; newValue.validity = newValidity; newValue.added = now; + newValue.tcp = tcp; newValue.value = std::string(response, responseLen); { @@ -77,7 +78,7 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty CacheValue& value = it->second; bool wasExpired = value.validity <= now; - if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass)) { + if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) { d_insertCollisions++; return; } @@ -90,9 +91,9 @@ void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qty } } -bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging) +bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, bool tcp, uint32_t* keyOut, bool skipAging) { - uint32_t key = getKey(qname, consumed, query, queryLen); + uint32_t key = getKey(qname, consumed, query, queryLen, tcp); if (keyOut) *keyOut = key; @@ -122,7 +123,7 @@ bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, cons } /* check for collision */ - if (!cachedValueMatches(value, qname, qtype, qclass)) { + if (!cachedValueMatches(value, qname, qtype, qclass, tcp)) { d_misses++; d_lookupCollisions++; return false; @@ -236,7 +237,7 @@ uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length) return result; } -uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen) +uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp) { uint32_t result = 0; /* skip the query ID */ @@ -244,6 +245,7 @@ uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, con string lc(qname.toDNSStringLC()); result = burtle((const unsigned char*) lc.c_str(), lc.length(), result); result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result); + result = burtle((const unsigned char*) &tcp, sizeof(tcp), result); return result; } diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh index 407b23849..76b1ad2f3 100644 --- a/pdns/dnsdist-cache.hh +++ b/pdns/dnsdist-cache.hh @@ -10,8 +10,8 @@ public: DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=60); ~DNSDistPacketCache(); - void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen); - bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging=false); + void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp); + bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, bool tcp, uint32_t* keyOut, bool skipAging=false); void purge(size_t upTo=0); void expunge(const DNSName& name, uint16_t qtype=QType::ANY); bool isFull(); @@ -38,10 +38,11 @@ private: time_t added{0}; time_t validity{0}; uint16_t len{0}; + bool tcp{false}; }; - static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen); - static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass); + static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp); + static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp); pthread_rwlock_t d_lock; std::unordered_map d_map; diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 932627842..e2014152f 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -348,7 +348,7 @@ void* tcpClientThread(int pipefd) if (serverPool->packetCache && !dq.skipCache) { char cachedResponse[4096]; uint16_t cachedResponseSize = sizeof cachedResponse; - if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) { + if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, true, &cacheKey)) { if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout)) writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout); g_stats.cacheHits++; @@ -476,7 +476,7 @@ void* tcpClientThread(int pipefd) } if (serverPool->packetCache && !dq.skipCache) { - serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen); + serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true); } #ifdef HAVE_DNSCRYPT diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index f11f92475..795be3161 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -248,7 +248,7 @@ void* responderThread(std::shared_ptr state) g_stats.responses++; if (ids->packetCache && !ids->skipCache) { - ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen); + ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false); } #ifdef HAVE_DNSCRYPT @@ -811,7 +811,7 @@ try if (serverPool->packetCache && !dq.skipCache) { char cachedResponse[4096]; uint16_t cachedResponseSize = sizeof cachedResponse; - if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) { + if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, false, &cacheKey)) { ComboAddress dest; if(HarvestDestinationAddress(&msgh, &dest)) sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote); diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 9a21a1eaf..f1d1085fc 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -39,12 +39,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key); + bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key); BOOST_CHECK_EQUAL(found, false); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen); + PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false); - found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, true); + found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, false, &key, true); if (found == true) { BOOST_CHECK_EQUAL(responseBufSize, responseLen); int match = memcmp(responseBuf, response.data(), responseLen); @@ -68,7 +68,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key); + bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key); if (found == true) { PC.expunge(a); deleted++; @@ -88,7 +88,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { uint32_t key = 0; char response[4096]; uint16_t responseSize = sizeof(response); - if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key)) { + if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, false, &key)) { matches++; } } @@ -126,9 +126,9 @@ static void *threadMangler(void* a) char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key); + PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen); + PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false); } } catch(PDNSException& e) { @@ -155,7 +155,7 @@ static void *threadReader(void* a) char responseBuf[4096]; uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; - bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key); + bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key); if (!found) { g_missing++; } diff --git a/regression-tests.dnsdist/test_Advanced.py b/regression-tests.dnsdist/test_Advanced.py index 64bda3431..a03102833 100644 --- a/regression-tests.dnsdist/test_Advanced.py +++ b/regression-tests.dnsdist/test_Advanced.py @@ -1256,22 +1256,37 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) for _ in range(numberOfQueries): (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) + total = 0 + for key in TestAdvancedCaching._responsesCounter: + total += TestAdvancedCaching._responsesCounter[key] + TestAdvancedCaching._responsesCounter[key] = 0 + + self.assertEquals(total, 1) + + # TCP should not be cached + # first query to fill the cache + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(receivedResponse, response) + + for _ in range(numberOfQueries): (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) total = 0 for key in TestAdvancedCaching._responsesCounter: total += TestAdvancedCaching._responsesCounter[key] + TestAdvancedCaching._responsesCounter[key] = 0 self.assertEquals(total, 1) @@ -1298,7 +1313,6 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) @@ -1306,7 +1320,6 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) @@ -1340,18 +1353,12 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) misses += 1 # next queries should hit the cache (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) # now we wait a bit for the cache entry to expire @@ -1362,18 +1369,12 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) misses += 1 # following queries should hit the cache again (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) total = 0 @@ -1415,18 +1416,12 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) misses += 1 # next queries should hit the cache (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) # now we wait a bit for the cache entry to expire @@ -1437,18 +1432,12 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) misses += 1 # following queries should hit the cache again (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) total = 0 @@ -1481,20 +1470,12 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) misses += 1 # next queries should hit the cache (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - for an in receivedResponse.answer: - self.assertTrue(an.ttl <= ttl) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) for an in receivedResponse.answer: self.assertTrue(an.ttl <= ttl) @@ -1504,13 +1485,6 @@ class TestAdvancedCaching(DNSDistTest): # next queries should hit the cache (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id - self.assertEquals(receivedResponse, response) - for an in receivedResponse.answer: - self.assertTrue(an.ttl < ttl) - - (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) - receivedResponse.id = response.id self.assertEquals(receivedResponse, response) for an in receivedResponse.answer: self.assertTrue(an.ttl < ttl) @@ -1549,18 +1523,13 @@ class TestAdvancedCaching(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, response) # different case query should still hit the cache (_, receivedResponse) = self.sendUDPQuery(differentCaseQuery, response=None, useQueue=False) - receivedResponse.id = differentCaseResponse.id self.assertEquals(receivedResponse, differentCaseResponse) - (_, receivedResponse) = self.sendTCPQuery(differentCaseQuery, response=None, useQueue=False) - receivedResponse.id = differentCaseResponse.id - self.assertEquals(receivedResponse, differentCaseResponse) class TestAdvancedCachingWithExistingEDNS(DNSDistTest): @@ -1592,7 +1561,6 @@ class TestAdvancedCachingWithExistingEDNS(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) misses += 1 @@ -1610,7 +1578,6 @@ class TestAdvancedCachingWithExistingEDNS(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) misses += 1 @@ -1648,7 +1615,6 @@ class TestAdvancedLogAction(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) @@ -1681,7 +1647,6 @@ class TestAdvancedDNSSEC(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) @@ -1689,7 +1654,6 @@ class TestAdvancedDNSSEC(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) @@ -1743,7 +1707,6 @@ class TestAdvancedQClass(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) @@ -1751,6 +1714,5 @@ class TestAdvancedQClass(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) -- 2.40.0