]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add a simple Packet Cache
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 15 Feb 2016 08:49:36 +0000 (09:49 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 15 Feb 2016 08:49:36 +0000 (09:49 +0100)
Per-pool Packet Cache, using the whole query packet minus the id
has hashing key, to prevent issue related to:
* EDNS Payload size
* ECS
* DNSSEC

The packet cache is not enabled by default, and can be skipped
for specific queries using SkipCacheAction.
It's a per-pool cache, in case you have different responses, but
you can use the same cache for several pools if you want to.

We cache the whole response and age the TTLs when fetching the
response from the cache.

This commit also refactors a bit the way server pools are handled
to be able to have a per-pool cache, and to avoid scanning all
servers when looking for the ones in a given pool.

It is using a fixed-size unordered_map to prevent rehashing. It
is not very efficient with regard to cache cleaning, but I really
would like to use only a ReadLock on the fastpath, and using a
multi index container and moving cache entries to the back / front
on hit / miss would prevent that.

Health checks are moved to a different thread, to prevent them from
being impacted by the cache cleaning operation being slow.

17 files changed:
pdns/README-dnsdist.md
pdns/dnsdist-cache.cc [new file with mode: 0644]
pdns/dnsdist-cache.hh [new file with mode: 0644]
pdns/dnsdist-lua.cc
pdns/dnsdist-lua2.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-cache.cc [new symlink]
pdns/dnsdistdist/dnsdist-cache.hh [new symlink]
pdns/dnsdistdist/test-dnsdistpacketcache_cc.cc [new symlink]
pdns/dnsparser.cc
pdns/dnsparser.hh
pdns/dnsrulactions.hh
pdns/test-dnsdistpacketcache_cc.cc [new file with mode: 0644]
regression-tests.dnsdist/test_Advanced.py

index 6d7b5c9f468e5adaf7fbf22d346eca45309e0b20..3a34f24c8c1e55dd44a2a9de579c665f6c814b5c 100644 (file)
@@ -313,6 +313,7 @@ Current actions are:
  * Delay a response by n milliseconds (DelayAction), over UDP only
  * Modify query to clear the RD or CD bit
  * Add the source MAC address to the query (MacAddrAction)
+ * Skip the cache, if any
 
 Rules can be added via:
 
@@ -695,6 +696,25 @@ fe80::/10
 ::/0
 ```
 
+Caching
+-------
+
+`dnsdist` implements a simple but effective packet cache, not enabled by default.
+It is enabled per-pool, but the same cache can be shared between several pools.
+The first step is to define a cache, then to assign that cache to the chosen pool,
+the default one being represented by the empty string:
+
+```
+pc = newPacketCache(10000, 86400, 600)
+getPool(""):setCache(pc)
+```
+
+The first parameter is the maximum number of entries stored in the cache, the
+second one, optional, is the maximum lifetime of an entry in the cache, in seconds,
+and the last one, optional too, is the minimum TTL an entry should have to be considered
+for insertion in the cache.
+
+
 Carbon/Graphite/Metronome
 -------------------------
 To emit metrics to Graphite, or any other software supporting the Carbon protocol, use:
@@ -856,6 +876,7 @@ instantiate a server with additional parameters
     * `QPSPoolAction()`: set the packet into the specified pool only if it does not exceed the specified QPS limits
     * `QPSAction()`: drop these packets if the QPS limits are exceeded
     * `RCodeAction()`: reply immediatly by turning the query into a response with the specified rcode
+    * `SkipCacheAction()`: don't lookup the cache for this query, don't store the answer
     * `SpoofAction()`: forge a response with the specified IPv4 (for an A query) or IPv6 (for an AAAA). If you specify two addresses, the first one should be an IPv4 and will be used for A, the second an IPv6 for an AAAA
     * `SpoofCNAMEAction()`: forge a response with the specified CNAME value
     * `TCAction()`: create answer to query with TC and RD bits set, to move to TCP/IP
@@ -875,7 +896,9 @@ instantiate a server with additional parameters
     * `addPoolRule(netmask, pool)`: send queries to this netmask to that pool
     * `addPoolRule({netmask, netmask}, pool)`: send queries to these netmasks to that pool  
     * `addQPSPoolRule(x, limit, pool)`: like `addPoolRule`, but only select at most 'limit' queries/s for this pool
+    * `getPool(poolname)`: return the ServerPool named `poolname`
     * `getPoolServers(pool)`: return servers part of this pool
+    * `showPools()`: list the current server pools
  * Lua Action related:
     * `addLuaAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a 
       function with parameters remote, qname, qtype, dh and len, which returns an action to be taken 
@@ -914,6 +937,16 @@ instantiate a server with additional parameters
     * `exceedRespByterate(rate, seconds)`: get set of addresses that exeeded `rate` bytes/s answers over `seconds` seconds
     * `exceedQRate(rate, seconds)`: get set of address that exceed `rate` queries/s over `seconds` seconds
     * `exceedQTypeRate(type, rate, seconds)`: get set of address that exceed `rate` queries/s for queries of type `type` over `seconds` seconds
+ * ServerPool related:
+    * `getCache()`: return the current packet cache, if any
+    * `setCache(PacketCache)`: set the cache for this pool
+ * PacketCache related:
+    * `expungeByName(DNSName)`: remove entries matching the supplied DNSName from the cache
+    * `isFull()`: return true if the cache has reached the maximum number of entries
+    * `newPacketCache(maxEntries, maxTTL=86400, minTTL=60)`: return a new PacketCache
+    * `printStats()`: print the cache stats (hits, misses, deferred lookups and deferred inserts)
+    * `purge()`: remove entries from the cache until it the number of entries is lower than the maximum, starting with expired ones.
+    * `toString()`: return the number of entries in the Packet Cache, and the maximum number of entries
  * Advanced functions for writing your own policies and hooks
     * ComboAddress related:
         * `newCA(address)`: return a new ComboAddress
@@ -963,6 +996,7 @@ instantiate a server with additional parameters
     * `setTCPSendTimeout(n)`: set the write timeout on TCP connections from the client, in seconds
     * `setMaxTCPClientThreads(n)`: set the maximum of TCP client threads, handling TCP connections
     * `setMaxUDPOutstanding(n)`: set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time
+    * `setCacheCleaningDelay(n)`: set the interval in seconds between to run of the cache cleaning algorithm, removing expired entries
  * DNSCrypt related:
     * `addDNSCryptBind("127.0.0.1:8443", "provider name", "/path/to/resolver.cert", "/path/to/resolver.key"):` listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of "provider name", using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files
     * `generateDNSCryptProviderKeys("/path/to/providerPublic.key", "/path/to/providerPrivate.key"):` generate a new provider keypair
diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc
new file mode 100644 (file)
index 0000000..092e02c
--- /dev/null
@@ -0,0 +1,254 @@
+#include "dolog.hh"
+#include "dnsdist-cache.hh"
+#include "dnsparser.hh"
+
+DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL): d_maxEntries(maxEntries), d_maxTTL(maxTTL), d_minTTL(minTTL)
+{
+  pthread_rwlock_init(&d_lock, 0);
+  /* we reserve maxEntries + 1 to avoid rehashing from occuring
+     when we get to maxEntries, as it means a load factor of 1 */
+  d_map.reserve(maxEntries + 1);
+}
+
+DNSDistPacketCache::~DNSDistPacketCache()
+{
+  WriteLock l(&d_lock);
+}
+
+bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass)
+{
+  if (cachedValue.qname != qname || cachedValue.qtype != qtype || cachedValue.qclass != qclass)
+    return false;
+  return true;
+}
+
+void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen)
+{
+  if (responseLen == 0)
+    return;
+
+  uint32_t minTTL = getMinTTL(response, responseLen);
+  if (minTTL > d_maxTTL)
+    minTTL = d_maxTTL;
+
+  if (minTTL < d_minTTL)
+    return;
+
+  {
+    TryReadLock r(&d_lock);
+    if (!r.gotIt()) {
+      d_deferredInserts++;
+      return;
+    }
+    if (d_map.size() >= d_maxEntries) {
+      return;
+    }
+  }
+
+  const time_t now = time(NULL);
+  std::unordered_map<uint32_t,CacheValue>::iterator it;
+  bool result;
+  time_t newValidity = now + minTTL;
+  CacheValue newValue;
+  newValue.qname = qname;
+  newValue.qtype = qtype;
+  newValue.qclass = qclass;
+  newValue.len = responseLen;
+  newValue.validity = newValidity;
+  newValue.added = now;
+  newValue.value = std::string(response, responseLen);
+
+  {
+    TryWriteLock w(&d_lock);
+
+    if (!w.gotIt()) {
+      d_deferredInserts++;
+      return;
+    }
+
+    tie(it, result) = d_map.insert({key, newValue});
+
+    if (result) {
+      return;
+    }
+
+    /* in case of collision, don't override the existing entry
+       except if it has expired */
+    CacheValue& value = it->second;
+    bool wasExpired = value.validity <= now;
+
+    if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass)) {
+      d_insertCollisions++;
+      return;
+    }
+
+    /* if the existing entry had a longer TTD, keep it */
+    if (newValidity <= value.validity)
+      return;
+
+    value = newValue;
+  }
+}
+
+bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging)
+{
+  uint32_t key = getKey(qname, consumed, query, queryLen);
+  if (keyOut)
+    *keyOut = key;
+
+  time_t now = time(NULL);
+  time_t age;
+  {
+    TryReadLock r(&d_lock);
+    if (!r.gotIt()) {
+      d_deferredLookups++;
+      return false;
+    }
+
+    std::unordered_map<uint32_t,CacheValue>::const_iterator it = d_map.find(key);
+    if (it == d_map.end()) {
+      d_misses++;
+      return false;
+    }
+
+    const CacheValue& value = it->second;
+    if (value.validity < now) {
+      d_misses++;
+      return false;
+    }
+
+    if (*responseLen < value.len) {
+      return false;
+    }
+
+    /* check for collision */
+    if (!cachedValueMatches(value, qname, qtype, qclass)) {
+      d_misses++;
+      d_lookupCollisions++;
+      return false;
+    }
+
+    string dnsQName(qname.toDNSString());
+    memcpy(response, &queryId, sizeof(queryId));
+    memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
+    memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQName.length());
+    memcpy(response + sizeof(dnsheader) + dnsQName.length(), value.value.c_str() + sizeof(dnsheader) + dnsQName.length(), value.value.length() - (sizeof(dnsheader) + dnsQName.length()));
+    *responseLen = value.len;
+    age = now - value.added;
+  }
+
+  if (!skipAging)
+    ageDNSPacket(response, *responseLen, age);
+  d_hits++;
+  return true;
+}
+
+void DNSDistPacketCache::purge(size_t upTo)
+{
+  time_t now = time(NULL);
+  WriteLock w(&d_lock);
+  if (upTo <= d_map.size())
+    return;
+
+  size_t toRemove = d_map.size() - upTo;
+  for(auto it = d_map.begin(); toRemove > 0 && it != d_map.end(); ) {
+    const CacheValue& value = it->second;
+
+    if (value.validity < now) {
+        it = d_map.erase(it);
+        --toRemove;
+    } else {
+      ++it;
+    }
+  }
+}
+
+void DNSDistPacketCache::expunge(const DNSName& name, uint16_t qtype)
+{
+  WriteLock w(&d_lock);
+
+  for(auto it = d_map.begin(); it != d_map.end(); ) {
+    const CacheValue& value = it->second;
+    uint16_t cqtype = 0;
+    uint16_t cqclass = 0;
+    DNSName cqname(value.value.c_str(), value.len, sizeof(dnsheader), false, &cqtype, &cqclass, nullptr);
+
+    if (cqname == name && (qtype == QType::ANY || qtype == cqtype)) {
+        it = d_map.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+
+bool DNSDistPacketCache::isFull()
+{
+    ReadLock r(&d_lock);
+    return (d_map.size() >= d_maxEntries);
+}
+
+uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length)
+{
+  const struct dnsheader* dh = (const struct dnsheader*) packet;
+  uint32_t result = std::numeric_limits<uint32_t>::max();
+  vector<uint8_t> content(length - sizeof(dnsheader));
+  copy(packet + sizeof(dnsheader), packet + length, content.begin());
+  PacketReader pr(content);
+  size_t idx = 0;
+  DNSName rrname;
+  uint16_t qdcount = ntohs(dh->qdcount);
+  uint16_t ancount = ntohs(dh->ancount);
+  uint16_t nscount = ntohs(dh->nscount);
+  uint16_t arcount = ntohs(dh->arcount);
+  uint16_t rrtype;
+  uint16_t rrclass;
+  struct dnsrecordheader ah;
+
+  /* consume qd */
+  for(idx = 0; idx < qdcount; idx++) {
+    rrname = pr.getName();
+    rrtype = pr.get16BitInt();
+    rrclass = pr.get16BitInt();
+    (void) rrtype;
+    (void) rrclass;
+  }
+
+  /* consume AN and NS */
+  for (idx = 0; idx < ancount + nscount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+    pr.d_pos += ah.d_clen;
+    if (result > ah.d_ttl)
+      result = ah.d_ttl;
+  }
+
+  /* consume AR, watch for OPT */
+  for (idx = 0; idx < arcount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+    pr.d_pos += ah.d_clen;
+    if (ah.d_type == QType::OPT) {
+      continue;
+    }
+    if (result > ah.d_ttl)
+      result = ah.d_ttl;
+  }
+  return result;
+}
+
+uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen)
+{
+  uint32_t result = 0;
+  /* skip the query ID */
+  result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
+  string lc(qname.toDNSStringLC());
+  result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
+  result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
+  return result;
+}
+
+string DNSDistPacketCache::toString()
+{
+  ReadLock r(&d_lock);
+  return std::to_string(d_map.size()) + "/" + std::to_string(d_maxEntries);
+}
diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh
new file mode 100644 (file)
index 0000000..407b238
--- /dev/null
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <atomic>
+#include <unordered_map>
+#include "lock.hh"
+
+class DNSDistPacketCache : boost::noncopyable
+{
+public:
+  DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=60);
+  ~DNSDistPacketCache();
+
+  void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen);
+  bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging=false);
+  void purge(size_t upTo=0);
+  void expunge(const DNSName& name, uint16_t qtype=QType::ANY);
+  bool isFull();
+  string toString();
+  uint64_t getSize() const { return d_map.size(); };
+  uint64_t getHits() const { return d_hits; };
+  uint64_t getMisses() const { return d_misses; };
+  uint64_t getDeferredLookups() const { return d_deferredLookups; };
+  uint64_t getDeferredInserts() const { return d_deferredInserts; };
+  uint64_t getLookupCollisions() const { return d_lookupCollisions; };
+  uint64_t getInsertCollisions() const { return d_insertCollisions; };
+
+  static uint32_t getMinTTL(const char* packet, uint16_t length);
+
+private:
+
+  struct CacheValue
+  {
+    time_t getTTD() const { return validity; }
+    std::string value;
+    DNSName qname;
+    uint16_t qtype{0};
+    uint16_t qclass{0};
+    time_t added{0};
+    time_t validity{0};
+    uint16_t len{0};
+  };
+
+  static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen);
+  static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass);
+
+  pthread_rwlock_t d_lock;
+  std::unordered_map<uint32_t,CacheValue> d_map;
+  std::atomic<uint64_t> d_deferredLookups{0};
+  std::atomic<uint64_t> d_deferredInserts{0};
+  std::atomic<uint64_t> d_hits{0};
+  std::atomic<uint64_t> d_misses{0};
+  std::atomic<uint64_t> d_insertCollisions{0};
+  std::atomic<uint64_t> d_lookupCollisions{0};
+  size_t d_maxEntries;
+  uint32_t d_maxTTL;
+  uint32_t d_minTTL;
+};
index 48625fb7bdbcd098a3461964994a389b4d66af60..79e7d115ad5195be4402c6abedc771fb78d06121 100644 (file)
@@ -243,6 +243,7 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                          ret->qps=QPSLimiter(qps, qps);
                        }
 
+                       auto localPools = g_pools.getCopy();
                        if(vars.count("pool")) {
                          if(auto* pool = boost::get<string>(&vars["pool"]))
                            ret->pools.insert(*pool);
@@ -251,7 +252,14 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                            for(auto& p : *pools)
                              ret->pools.insert(p.second);
                          }
+                         for(const auto& poolName: ret->pools) {
+                           addServerToPool(localPools, poolName, ret);
+                         }
+                       }
+                       else {
+                         addServerToPool(localPools, "", ret);
                        }
+                       g_pools.setState(localPools);
 
                        if(vars.count("order")) {
                          ret->order=std::stoi(boost::get<string>(vars["order"]));
@@ -369,12 +377,23 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                      [](boost::variant<std::shared_ptr<DownstreamState>, int> var)
                      { 
                         setLuaSideEffect();
-                       auto states = g_dstates.getCopy();
-                       if(auto* rem = boost::get<shared_ptr<DownstreamState>>(&var))
-                         states.erase(remove(states.begin(), states.end(), *rem), states.end());
-                       else
-                         states.erase(states.begin() + boost::get<int>(var));
-                       g_dstates.setState(states);
+                        shared_ptr<DownstreamState> server;
+                        auto* rem = boost::get<shared_ptr<DownstreamState>>(&var);
+                        auto states = g_dstates.getCopy();
+                        if(rem) {
+                          server = *rem;
+                        }
+                        else {
+                          int idx = boost::get<int>(var);
+                          server = states[idx];
+                        }
+                        auto localPools = g_pools.getCopy();
+                        for (const string& poolName : server->pools) {
+                          removeServerFromPool(localPools, poolName, server);
+                        }
+                        g_pools.setState(localPools);
+                        states.erase(remove(states.begin(), states.end(), server), states.end());
+                        g_dstates.setState(states);
                      } );
 
 
@@ -610,7 +629,6 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       return std::shared_ptr<DNSAction>(new DelayAction(msec));
     });
 
-
   g_lua.writeFunction("TCAction", []() {
       return std::shared_ptr<DNSAction>(new TCAction);
     });
@@ -627,6 +645,10 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       return std::shared_ptr<DNSAction>(new RCodeAction(rcode));
     });
 
+  g_lua.writeFunction("SkipCacheAction", []() {
+      return std::shared_ptr<DNSAction>(new SkipCacheAction);
+    });
+
   g_lua.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional<int> ipv4trunc, boost::optional<int> ipv6trunc) {
       return std::shared_ptr<DNSRule>(new MaxQPSIPRule(qps, ipv4trunc.get_value_or(32), ipv6trunc.get_value_or(64)));
     });
@@ -826,7 +848,7 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
     });
 
   g_lua.writeFunction("getPoolServers", [](string pool) {
-      return getDownstreamCandidates(g_dstates.getCopy(), pool);
+      return getDownstreamCandidates(g_pools.getCopy(), pool);
     });
 
   g_lua.writeFunction("getServer", [client](int i) {
@@ -836,8 +858,18 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
     });
 
   g_lua.registerFunction<void(DownstreamState::*)(int)>("setQPS", [](DownstreamState& s, int lim) { s.qps = lim ? QPSLimiter(lim, lim) : QPSLimiter(); });
-  g_lua.registerFunction<void(DownstreamState::*)(string)>("addPool", [](DownstreamState& s, string pool) { s.pools.insert(pool);});
-  g_lua.registerFunction<void(DownstreamState::*)(string)>("rmPool", [](DownstreamState& s, string pool) { s.pools.erase(pool);});
+  g_lua.registerFunction<void(std::shared_ptr<DownstreamState>::*)(string)>("addPool", [](std::shared_ptr<DownstreamState> s, string pool) {
+      auto localPools = g_pools.getCopy();
+      addServerToPool(localPools, pool, s);
+      g_pools.setState(localPools);
+      s->pools.insert(pool);
+    });
+  g_lua.registerFunction<void(std::shared_ptr<DownstreamState>::*)(string)>("rmPool", [](std::shared_ptr<DownstreamState> s, string pool) {
+      auto localPools = g_pools.getCopy();
+      removeServerFromPool(localPools, pool, s);
+      g_pools.setState(localPools);
+      s->pools.erase(pool);
+    });
 
   g_lua.registerFunction<void(DownstreamState::*)()>("getOutstanding", [](const DownstreamState& s) { g_outputBuffer=std::to_string(s.outstanding.load()); });
 
@@ -1230,6 +1262,8 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
 
   g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { g_maxTCPClientThreads = max; });
 
+  g_lua.writeFunction("setCacheCleaningDelay", [](uint32_t delay) { g_cacheCleaningDelay = delay; });
+
   g_lua.writeFunction("setECSSourcePrefixV4", [](uint16_t prefix) { g_ECSSourcePrefixV4=prefix; });
 
   g_lua.writeFunction("setECSSourcePrefixV6", [](uint16_t prefix) { g_ECSSourcePrefixV6=prefix; });
@@ -1280,7 +1314,7 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       }
     });
 
-  moreLua();
+  moreLua(client);
   
   std::ifstream ifs(config);
   if(!ifs) 
index 9b52a4ad1652f1ca9daa7e7b51b19a0660d38118..68cab6987dbd7673c1333810f345985a19c536da 100644 (file)
@@ -1,4 +1,5 @@
 #include "dnsdist.hh"
+#include "dnsdist-cache.hh"
 #include "dnsrulactions.hh"
 #include <thread>
 #include "dolog.hh"
@@ -118,7 +119,7 @@ map<ComboAddress,int> exceedRespByterate(int rate, int seconds)
 }
 
 
-void moreLua()
+void moreLua(bool client)
 {
   typedef NetmaskTree<DynBlock> nmts_t;
   g_lua.writeFunction("newCA", [](const std::string& name) { return ComboAddress(name); });
@@ -493,4 +494,68 @@ void moreLua()
 #endif
     });
 
+    g_lua.writeFunction("showPools", []() {
+      setLuaNoSideEffect();
+      try {
+        ostringstream ret;
+        boost::format fmt("%1$-20.20s %|25t|%2$20s %|50t|%3%" );
+        //             1        3         4
+        ret << (fmt % "Name" % "Cache" % "Servers" ) << endl;
+
+        const auto localPools = g_pools.getCopy();
+        for (const auto& entry : localPools) {
+          const string& name = entry.first;
+          const std::shared_ptr<ServerPool> pool = entry.second;
+          string cache = pool->packetCache != nullptr ? pool->packetCache->toString() : "";
+          string servers;
+
+          for (const auto& server: pool->servers) {
+            if (!servers.empty()) {
+              servers += ", ";
+            }
+            if (!server.second->name.empty()) {
+              servers += server.second->name;
+              servers += " ";
+            }
+            servers += server.second->remote.toStringWithPort();
+          }
+
+          ret << (fmt % name % cache % servers) << endl;
+        }
+        g_outputBuffer=ret.str();
+      }catch(std::exception& e) { g_outputBuffer=e.what(); throw; }
+    });
+
+    g_lua.registerFunction<void(std::shared_ptr<ServerPool>::*)(std::shared_ptr<DNSDistPacketCache>)>("setCache", [](std::shared_ptr<ServerPool> pool, std::shared_ptr<DNSDistPacketCache> cache) {
+        pool->packetCache = cache;
+    });
+    g_lua.registerFunction("getCache", &ServerPool::getCache);
+
+    g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional<uint32_t> maxTTL, boost::optional<uint32_t> minTTL) {
+        return std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 60);
+      });
+    g_lua.registerFunction("toString", &DNSDistPacketCache::toString);
+    g_lua.registerFunction("isFull", &DNSDistPacketCache::isFull);
+    g_lua.registerFunction("purge", &DNSDistPacketCache::purge);
+    g_lua.registerFunction<void(std::shared_ptr<DNSDistPacketCache>::*)(const DNSName& dname, boost::optional<uint16_t> qtype)>("expungeByName", [](std::shared_ptr<DNSDistPacketCache> cache, const DNSName& dname, boost::optional<uint16_t> qtype) {
+        cache->expunge(dname, qtype ? *qtype : QType::ANY);
+      });
+    g_lua.registerFunction<void(std::shared_ptr<DNSDistPacketCache>::*)()>("printStats", [](const std::shared_ptr<DNSDistPacketCache> cache) {
+        g_outputBuffer="Hits: " + std::to_string(cache->getHits()) + "\n";
+        g_outputBuffer+="Misses: " + std::to_string(cache->getMisses()) + "\n";
+        g_outputBuffer+="Deferred inserts: " + std::to_string(cache->getDeferredInserts()) + "\n";
+        g_outputBuffer+="Deferred lookups: " + std::to_string(cache->getDeferredLookups()) + "\n";
+        g_outputBuffer+="Lookup Collisions: " + std::to_string(cache->getLookupCollisions()) + "\n";
+        g_outputBuffer+="Insert Collisions: " + std::to_string(cache->getInsertCollisions()) + "\n";
+      });
+
+    g_lua.writeFunction("getPool", [client](const string& poolName) {
+        if (client) {
+          return std::make_shared<ServerPool>();
+        }
+        auto localPools = g_pools.getCopy();
+        std::shared_ptr<ServerPool> pool = createPoolIfNotExists(localPools, poolName);
+        g_pools.setState(localPools);
+        return pool;
+      });
 }
index ce57f5b0cc8226bae52576acad1d25288a1133de..9326278423b54613d565fd4c1819a7e03a4fa6fe 100644 (file)
@@ -150,6 +150,7 @@ void* tcpClientThread(int pipefd)
   auto localPolicy = g_policy.getLocal();
   auto localRulactions = g_rulactions.getLocal();
   auto localDynBlockNMG = g_dynblockNMG.getLocal();
+  auto localPools = g_pools.getLocal();
 
   map<ComboAddress,int> sockets;
   for(;;) {
@@ -161,7 +162,7 @@ void* tcpClientThread(int pipefd)
     delete citmp;    
 
     uint16_t qlen, rlen;
-    string pool
+    string poolname;
     const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
     const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
     const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
@@ -298,7 +299,7 @@ void* tcpClientThread(int pipefd)
               break;
             /* non-terminal actions follow */
             case DNSAction::Action::Pool:
-              pool=ruleresult;
+              poolname=ruleresult;
               break;
             case DNSAction::Action::Delay:
             case DNSAction::Action::None:
@@ -321,9 +322,10 @@ void* tcpClientThread(int pipefd)
        if(dq.qtype == QType::AXFR || dq.qtype == QType::IXFR)  // XXX fixme we really need to do better
          break;
 
+        std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
-         ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), &dq);
+         ds = localPolicy->policy(serverPool->servers, &dq);
        }
        if(!ds) {
          g_stats.noPolicy++;
@@ -342,6 +344,19 @@ void* tcpClientThread(int pipefd)
           }
         }
 
+        uint32_t cacheKey = 0;
+        if (serverPool->packetCache && !dq.skipCache) {
+          char cachedResponse[4096];
+          uint16_t cachedResponseSize = sizeof cachedResponse;
+          if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+            if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
+              writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
+            g_stats.cacheHits++;
+            goto drop;
+          }
+          g_stats.cacheMisses++;
+        }
+
        int dsock;
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds);
@@ -460,6 +475,10 @@ void* tcpClientThread(int pipefd)
          memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
        }
 
+       if (serverPool->packetCache && !dq.skipCache) {
+         serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen);
+       }
+
 #ifdef HAVE_DNSCRYPT
         if (ci.cs->dnscryptCtx) {
           uint16_t encryptedResponseLen = 0;
index d3aa7822c7cd59e5560f14f61df5db383ca0c9ad..560d598901409d8e979b86defbdb0fa8e7651eda 100644 (file)
@@ -40,6 +40,7 @@
 #include <pwd.h>
 #include "lock.hh"
 #include <getopt.h>
+#include "dnsdist-cache.hh"
 
 /* Known sins:
 
@@ -68,6 +69,7 @@ vector<std::pair<ComboAddress, bool>> g_locals;
 std::vector<std::pair<ComboAddress,DnsCryptContext>> g_dnsCryptLocals;
 #endif
 vector<ClientState *> g_frontends;
+GlobalStateHolder<pools_t> g_pools;
 
 /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
    Then we have a bunch of connected sockets for talking to downstream servers. 
@@ -237,6 +239,10 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
 
     g_stats.responses++;
 
+    if (ids->packetCache && !ids->skipCache) {
+      ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen);
+    }
+
 #ifdef HAVE_DNSCRYPT
     uint16_t encryptedResponseLen = 0;
     if(ids->dnsCryptQuery) {
@@ -336,6 +342,10 @@ shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers,
 // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
 shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
+  if (servers.size() == 1 && servers[0].second->isUp()) {
+    return servers[0].second;
+  }
+
   vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
   /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
      which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
@@ -444,16 +454,62 @@ static void daemonize(void)
 
 ComboAddress g_serverControl{"127.0.0.1:5199"};
 
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
+{
+  std::shared_ptr<ServerPool> pool;
+  pools_t::iterator it = pools.find(poolName);
+  if (it != pools.end()) {
+    pool = it->second;
+  }
+  else {
+    vinfolog("Creating pool %s", poolName);
+    pool = std::make_shared<ServerPool>();
+    pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
+  }
+  return pool;
+}
 
-NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool)
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
 {
-  NumberedServerVector ret;
-  int count=0;
-  for(const auto& s : servers) 
-    if((pool.empty() && s->pools.empty()) || s->pools.count(pool))
-      ret.push_back(make_pair(++count, s));
-  
-  return ret;
+  std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+  unsigned int count = pool->servers.size();
+  vinfolog("Adding server to pool %s", poolName);
+  pool->servers.push_back(make_pair(++count, server));
+}
+
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+  vinfolog("Removing from pool %s", poolName);
+  pools_t::iterator poolIt = pools.find(poolName);
+  if (poolIt == pools.end()) {
+    throw std::out_of_range("No pool named " + poolName);
+  }
+
+  std::shared_ptr<ServerPool> pool = poolIt->second;
+
+  for (NumberedVector<shared_ptr<DownstreamState> >::iterator it = pool->servers.begin(); it != pool->servers.end(); it++) {
+    if (it->second == server) {
+      pool->servers.erase(it);
+      break;
+    }
+  }
+}
+
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
+{
+  pools_t::const_iterator it = pools.find(poolName);
+
+  if (it == pools.end()) {
+    throw std::out_of_range("No pool named " + poolName);
+  }
+
+  return it->second;
+}
+
+const NumberedServerVector& getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
+{
+  std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+  return pool->servers;
 }
 
 // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
@@ -535,6 +591,7 @@ try
   auto localRulactions = g_rulactions.getLocal();
   auto localServers = g_dstates.getLocal();
   auto localDynBlock = g_dynblockNMG.getLocal();
+  auto localPools = g_pools.getLocal();
   struct msghdr msgh;
   struct iovec iov;
   /* used by HarvestDestinationAddress */
@@ -644,7 +701,7 @@ try
 
       DNSAction::Action action=DNSAction::Action::None;
       string ruleresult;
-      string pool;
+      string poolname;
       int delayMsec=0;
       bool done=false;
       for(const auto& lr : *localRulactions) {
@@ -675,7 +732,7 @@ try
             break;
           /* non-terminal actions follow */
           case DNSAction::Action::Pool:
-            pool=ruleresult;
+            poolname=ruleresult;
             break;
           case DNSAction::Action::Delay:
             delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
@@ -720,12 +777,12 @@ try
         continue;
       }
 
-      DownstreamState* ss = 0;
-      auto candidates=getDownstreamCandidates(*localServers, pool);
+      DownstreamState* ss = nullptr;
+      std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
       auto policy=localPolicy->policy;
       {
        std::lock_guard<std::mutex> lock(g_luamutex);
-       ss = policy(candidates, &dq).get();
+       ss = policy(serverPool->servers, &dq).get();
       }
 
       if(!ss) {
@@ -733,6 +790,27 @@ try
        continue;
       }
 
+      bool ednsAdded = false;
+      if (ss->useECS) {
+        handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), remote);
+      }
+
+      uint32_t cacheKey = 0;
+      if (serverPool->packetCache && !dq.skipCache) {
+        char cachedResponse[4096];
+        uint16_t cachedResponseSize = sizeof cachedResponse;
+        if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+          ComboAddress dest;
+          if(HarvestDestinationAddress(&msgh, &dest))
+            sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
+          else
+            sendto(cs->udpFD, cachedResponse, cachedResponseSize, 0, (struct sockaddr*)&remote, remote.getSocklen());
+          g_stats.cacheHits++;
+          continue;
+        }
+        g_stats.cacheMisses++;
+      }
+
       ss->queries++;
 
       unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
@@ -752,10 +830,15 @@ try
       ids->sentTime.start();
       ids->qname = qname;
       ids->qtype = dq.qtype;
+      ids->qclass = dq.qclass;
       ids->origDest.sin4.sin_family=0;
       ids->delayMsec = delayMsec;
       ids->origFlags = origFlags;
       ids->ednsAdded = false;
+      ids->cacheKey = cacheKey;
+      ids->skipCache = dq.skipCache;
+      ids->packetCache = serverPool->packetCache;
+      ids->ednsAdded = ednsAdded;
 #ifdef HAVE_DNSCRYPT
       ids->dnsCryptQuery = dnsCryptQuery;
 #endif
@@ -763,10 +846,6 @@ try
 
       dh->id = idOffset;
 
-      if (ss->useECS) {
-        handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ids->ednsAdded), remote);
-      }
-
       if (largerQuery.empty()) {
         ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len);
       }
@@ -853,8 +932,38 @@ catch(...)
 }
 
 std::atomic<uint64_t> g_maxTCPClientThreads{10};
+std::atomic<uint16_t> g_cacheCleaningDelay{60};
 
 void* maintThread()
+{
+  int interval = 1;
+  size_t counter = 0;
+
+  for(;;) {
+    sleep(interval);
+
+    std::lock_guard<std::mutex> lock(g_luamutex);
+    auto f =g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
+    if(f)
+      (*f)();
+
+    counter++;
+    if (counter >= g_cacheCleaningDelay) {
+      const auto localPools = g_pools.getCopy();
+      for (const auto& entry : localPools) {
+        if (entry.second->packetCache) {
+          entry.second->packetCache->purge();
+        }
+      }
+      counter = 0;
+    }
+
+    // ponder pruning g_dynblocks of expired entries here
+  }
+  return 0;
+}
+
+void* healthChecksThread()
 {
   int interval = 1;
 
@@ -900,14 +1009,6 @@ void* maintThread()
         }          
       }
     }
-    
-    std::lock_guard<std::mutex> lock(g_luamutex);
-    auto f =g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
-    if(f)
-      (*f)();
-    
-
-    // ponder pruning g_dynblocks of expired entries here
   }
   return 0;
 }
@@ -1278,14 +1379,16 @@ try
   for(auto& t : todo)
     t();
 
-
+  auto localPools = g_pools.getCopy();
   if(g_cmdLine.remotes.size()) {
     for(const auto& address : g_cmdLine.remotes) {
       auto ret=std::make_shared<DownstreamState>(ComboAddress(address, 53));
+      addServerToPool(localPools, "", ret);
       ret->tid = move(thread(responderThread, ret));
       g_dstates.modify([ret](servers_t& servers) { servers.push_back(ret); });
     }
   }
+  g_pools.setState(localPools);
 
   if(g_dstates.getCopy().empty()) {
     errlog("No downstream servers defined: all packets will get dropped");
@@ -1315,12 +1418,15 @@ try
   carbonthread.detach();
 
   thread stattid(maintThread);
+  stattid.detach();
   
+  thread healththread(healthChecksThread);
+
   if(g_cmdLine.beDaemon || g_cmdLine.beSupervised) {
-    stattid.join();
+    healththread.join();
   }
   else {
-    stattid.detach();
+    healththread.detach();
     doConsole();
   }
   _exit(EXIT_SUCCESS);
index 175a5f5bb2b645d70bff33cf999814a626dc917a..f44fbbad57041e5e7e824d95cf1243db1e9a6d92 100644 (file)
@@ -12,6 +12,7 @@
 #include <thread>
 #include "sholder.hh"
 #include "dnscrypt.hh"
+#include "dnsdist-cache.hh"
 void* carbonDumpThread();
 uint64_t uptimeOfProcess(const std::string& str);
 
@@ -53,6 +54,8 @@ struct DNSDistStats
   stat_t downstreamSendErrors{0};
   stat_t truncFail{0};
   stat_t noPolicy{0};
+  stat_t cacheHits{0};
+  stat_t cacheMisses{0};
   stat_t latency0_1{0}, latency1_10{0}, latency10_50{0}, latency50_100{0}, latency100_1000{0}, latencySlow{0};
   
   double latencyAvg100{0}, latencyAvg1000{0}, latencyAvg10000{0}, latencyAvg1000000{0};
@@ -75,6 +78,8 @@ struct DNSDistStats
     {"noncompliant-queries", &nonCompliantQueries},
     {"rdqueries", &rdQueries},
     {"empty-queries", &emptyQueries},
+    {"cache-hits", &cacheHits},
+    {"cache-misses", &cacheMisses},
     {"cpu-user-msec", getCPUTimeUser},
     {"cpu-sys-msec", getCPUTimeSystem},
     {"fd-usage", getOpenFileDescriptors}, {"dyn-blocked", &dynBlocked}, 
@@ -199,12 +204,16 @@ struct IDState
 #ifdef HAVE_DNSCRYPT
   std::shared_ptr<DnsCryptQuery> dnsCryptQuery{0};
 #endif
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+  uint32_t cacheKey;                                          // 8
   std::atomic<uint16_t> age;                                  // 4
   uint16_t qtype;                                             // 2
+  uint16_t qclass;                                            // 2
   uint16_t origID;                                            // 2
   uint16_t origFlags;                                         // 2
   int delayMsec;
   bool ednsAdded{false};
+  bool skipCache{false};
 };
 
 struct Rings {
@@ -362,6 +371,7 @@ struct DNSQuestion
   size_t size;
   uint16_t len;
   const bool tcp;
+  bool skipCache{false};
 };
 
 template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
@@ -405,6 +415,17 @@ struct ServerPolicy
   policyfunc_t policy;
 };
 
+struct ServerPool
+{
+  const std::shared_ptr<DNSDistPacketCache> getCache() const { return packetCache; };
+
+  NumberedVector<shared_ptr<DownstreamState>> servers;
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+};
+using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+
 struct CarbonConfig
 {
   ComboAddress server{"0.0.0.0", 0};
@@ -425,6 +446,7 @@ enum ednsOptionCodes {
 extern GlobalStateHolder<CarbonConfig> g_carbon;
 extern GlobalStateHolder<ServerPolicy> g_policy;
 extern GlobalStateHolder<servers_t> g_dstates;
+extern GlobalStateHolder<pools_t> g_pools;
 extern GlobalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > > g_rulactions;
 extern GlobalStateHolder<NetmaskGroup> g_ACL;
 
@@ -440,6 +462,7 @@ extern int g_tcpSendTimeout;
 extern uint16_t g_maxOutstanding;
 extern std::atomic<bool> g_configurationDone;
 extern std::atomic<uint64_t> g_maxTCPClientThreads;
+extern std::atomic<uint16_t> g_cacheCleaningDelay;
 extern uint16_t g_ECSSourcePrefixV4;
 extern uint16_t g_ECSSourcePrefixV6;
 extern bool g_ECSOverride;
@@ -448,7 +471,9 @@ struct dnsheader;
 
 void controlThread(int fd, ComboAddress local);
 vector<std::function<void(void)>> setupLua(bool client, const std::string& config);
-NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool);
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
+const NumberedServerVector& getDownstreamCandidates(const pools_t& pools, const std::string& poolName);
 
 std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq);
 
@@ -464,7 +489,7 @@ bool getMsgLen32(int fd, uint32_t* len);
 bool putMsgLen32(int fd, uint32_t len);
 void* tcpAcceptorThread(void* p);
 
-void moreLua();
+void moreLua(bool client);
 void doClient(ComboAddress server, const std::string& command);
 void doConsole();
 void controlClientThread(int fd, ComboAddress client);
index f71fb7fbe909c6bfeb4ea0e01ba53873e031d776..dafc9e9283915375692ef98257ab56e36f542eb3 100644 (file)
@@ -44,6 +44,7 @@ dnsdist_SOURCES = \
        dns.cc dns.hh \
        dnscrypt.cc dnscrypt.hh \
        dnsdist.cc dnsdist.hh \
+       dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-carbon.cc \
        dnsdist-console.cc \
        dnsdist-dnscrypt.cc \
@@ -94,8 +95,10 @@ testrunner_SOURCES = \
        dns.hh \
        test-base64_cc.cc \
        test-dnsdist_cc.cc \
+       test-dnsdistpacketcache_cc.cc \
        test-dnscrypt_cc.cc \
        dnsdist.hh \
+       dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnscrypt.cc dnscrypt.hh \
        dnslabeltext.cc \
diff --git a/pdns/dnsdistdist/dnsdist-cache.cc b/pdns/dnsdistdist/dnsdist-cache.cc
new file mode 120000 (symlink)
index 0000000..9730d71
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-cache.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/dnsdist-cache.hh b/pdns/dnsdistdist/dnsdist-cache.hh
new file mode 120000 (symlink)
index 0000000..84794d8
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-cache.hh
\ No newline at end of file
diff --git a/pdns/dnsdistdist/test-dnsdistpacketcache_cc.cc b/pdns/dnsdistdist/test-dnsdistpacketcache_cc.cc
new file mode 120000 (symlink)
index 0000000..dde3be0
--- /dev/null
@@ -0,0 +1 @@
+../test-dnsdistpacketcache_cc.cc
\ No newline at end of file
index 2d636d82fcfc193b11aceb56179c3f721d76fce3..67f64a32e06ecc36bfced26ecd7ee053c29075ef 100644 (file)
@@ -565,7 +565,10 @@ class DNSPacketMangler
 {
 public:
   explicit DNSPacketMangler(std::string& packet)
-    : d_packet(packet), d_notyouroffset(12), d_offset(d_notyouroffset)
+    : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
+  {}
+  DNSPacketMangler(char* packet, size_t length)
+    : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
   {}
   
   void skipLabel()
@@ -585,7 +588,7 @@ public:
   }
   uint16_t get16BitInt()
   {
-    const char* p = d_packet.c_str() + d_offset;
+    const char* p = d_packet + d_offset;
     moveOffset(2);
     uint16_t ret;
     memcpy(&ret, (void*)p, 2);
@@ -594,7 +597,7 @@ public:
   
   uint8_t get8BitInt()
   {
-    const char* p = d_packet.c_str() + d_offset;
+    const char* p = d_packet + d_offset;
     moveOffset(1);
     return *p;
   }
@@ -606,7 +609,7 @@ public:
   }
   void decreaseAndSkip32BitInt(uint32_t decrease)
   {
-    const char *p = (const char*)d_packet.c_str() + d_offset;
+    const char *p = d_packet + d_offset;
     moveOffset(4);
     
     uint32_t tmp;
@@ -614,17 +617,18 @@ public:
     tmp = ntohl(tmp);
     tmp-=decrease;
     tmp = htonl(tmp);
-    d_packet.replace(d_offset-4, sizeof(tmp), (const char*)&tmp, sizeof(tmp));
+    memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
   }
 private:
   void moveOffset(uint16_t by)
   {
     d_notyouroffset += by;
-    if(d_notyouroffset > d_packet.length())
+    if(d_notyouroffset > d_length)
       throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " 
-      + std::to_string(d_packet.length()) );
+      + std::to_string(d_length) );
   }
-  std::string& d_packet;
+  char* d_packet;
+  size_t d_length;
   
   uint32_t d_notyouroffset;  // only 'moveOffset' can touch this
   const uint32_t&  d_offset; // look.. but don't touch
@@ -632,16 +636,16 @@ private:
 };
 
 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
-void ageDNSPacket(std::string& packet, uint32_t seconds)
+void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
 {
-  if(packet.length() < sizeof(dnsheader))
+  if(length < sizeof(dnsheader))
     return;
   try 
   {
     dnsheader dh;
-    memcpy((void*)&dh, (const dnsheader*)packet.c_str(), sizeof(dh));
+    memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
     int numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
-    DNSPacketMangler dpm(packet);
+    DNSPacketMangler dpm(packet, length);
     
     int n;
     for(n=0; n < ntohs(dh.qdcount) ; ++n) {
@@ -667,3 +671,8 @@ void ageDNSPacket(std::string& packet, uint32_t seconds)
     return;
   }
 }
+
+void ageDNSPacket(std::string& packet, uint32_t seconds)
+{
+  ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
+}
index ef7c080fb6222041e0183dc7d7f0d6a234ab7fbe..94df048ec7b8797b984fe99a0f9f5864a08da1d0 100644 (file)
@@ -344,6 +344,7 @@ private:
 
 string simpleCompress(const string& label, const string& root="");
 void simpleExpandTo(const string& label, unsigned int frompos, string& ret);
+void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
 void ageDNSPacket(std::string& packet, uint32_t seconds);
 
 template<typename T>
index 9317433309e8d41312440ba369637361446a8a12..e70571b0b3aef24a4fa087dacd2f88cbbf1f3b5a 100644 (file)
@@ -646,3 +646,17 @@ public:
     return "set cd=1";
   }
 };
+
+class SkipCacheAction : public DNSAction
+{
+public:
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+  {
+    dq->skipCache = true;
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "skip cache";
+  }
+};
diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc
new file mode 100644 (file)
index 0000000..9a21a1e
--- /dev/null
@@ -0,0 +1,197 @@
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "iputils.hh"
+#include "dnsdist-cache.hh"
+#include "dnswriter.hh"
+
+BOOST_AUTO_TEST_SUITE(dnsdistpacketcache_cc)
+
+BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
+  const size_t maxEntries = 150000;
+  DNSDistPacketCache PC(maxEntries, 86400, 1);
+  BOOST_CHECK_EQUAL(PC.getSize(), 0);
+
+  size_t counter=0;
+  size_t skipped=0;
+  try {
+    for(counter = 0; counter < 100000; ++counter) {
+      DNSName a=DNSName("hello ")+DNSName(std::to_string(counter));
+      BOOST_CHECK_EQUAL(DNSName(a.toString()), a);
+
+      vector<uint8_t> query;
+      DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+      pwQ.getHeader()->rd = 1;
+
+      vector<uint8_t> response;
+      DNSPacketWriter pwR(response, a, QType::A, QClass::IN, 0);
+      pwR.getHeader()->rd = 1;
+      pwR.getHeader()->ra = 1;
+      pwR.getHeader()->qr = 1;
+      pwR.getHeader()->id = pwQ.getHeader()->id;
+      pwR.startRecord(a, QType::A, 100, QClass::IN, DNSResourceRecord::ANSWER);
+      pwR.xfr32BitInt(0x01020304);
+      pwR.commit();
+      uint16_t responseLen = response.size();
+
+      char responseBuf[4096];
+      uint16_t responseBufSize = sizeof(responseBuf);
+      uint32_t key = 0;
+      bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+      BOOST_CHECK_EQUAL(found, false);
+
+      PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+
+      found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, true);
+      if (found == true) {
+        BOOST_CHECK_EQUAL(responseBufSize, responseLen);
+        int match = memcmp(responseBuf, response.data(), responseLen);
+        BOOST_CHECK_EQUAL(match, 0);
+      }
+      else {
+        skipped++;
+      }
+    }
+
+    BOOST_CHECK_EQUAL(skipped, PC.getInsertCollisions());
+    BOOST_CHECK_EQUAL(PC.getSize(), counter - skipped);
+
+    size_t deleted=0;
+    size_t delcounter=0;
+    for(delcounter=0; delcounter < counter/1000; ++delcounter) {
+      DNSName a=DNSName("hello ")+DNSName(std::to_string(delcounter));
+      vector<uint8_t> query;
+      DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+      pwQ.getHeader()->rd = 1;
+      char responseBuf[4096];
+      uint16_t responseBufSize = sizeof(responseBuf);
+      uint32_t key = 0;
+      bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+      if (found == true) {
+        PC.expunge(a);
+        deleted++;
+      }
+    }
+    BOOST_CHECK_EQUAL(PC.getSize(), counter - skipped - deleted);
+
+    size_t matches=0;
+    vector<DNSResourceRecord> entry;
+    size_t expected=counter-skipped-deleted;
+    for(; delcounter < counter; ++delcounter) {
+      DNSName a(DNSName("hello ")+DNSName(std::to_string(delcounter)));
+      vector<uint8_t> query;
+      DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+      pwQ.getHeader()->rd = 1;
+      uint16_t len = query.size();
+      uint32_t key = 0;
+      char response[4096];
+      uint16_t responseSize = sizeof(response);
+      if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key)) {
+       matches++;
+      }
+    }
+    BOOST_CHECK_EQUAL(matches, expected);
+  }
+  catch(PDNSException& e) {
+    cerr<<"Had error: "<<e.reason<<endl;
+    throw;
+  }
+}
+
+static DNSDistPacketCache PC(500000);
+
+static void *threadMangler(void* a)
+{
+  try {
+    unsigned int offset=(unsigned int)(unsigned long)a;
+    for(unsigned int counter=0; counter < 100000; ++counter) {
+      DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset));
+      vector<uint8_t> query;
+      DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+      pwQ.getHeader()->rd = 1;
+
+      vector<uint8_t> response;
+      DNSPacketWriter pwR(response, a, QType::A, QClass::IN, 0);
+      pwR.getHeader()->rd = 1;
+      pwR.getHeader()->ra = 1;
+      pwR.getHeader()->qr = 1;
+      pwR.getHeader()->id = pwQ.getHeader()->id;
+      pwR.startRecord(a, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER);
+      pwR.xfr32BitInt(0x01020304);
+      pwR.commit();
+      uint16_t responseLen = response.size();
+
+      char responseBuf[4096];
+      uint16_t responseBufSize = sizeof(responseBuf);
+      uint32_t key = 0;
+      PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+
+      PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+    }
+  }
+  catch(PDNSException& e) {
+    cerr<<"Had error: "<<e.reason<<endl;
+    throw;
+  }
+  return 0;
+}
+
+AtomicCounter g_missing;
+
+static void *threadReader(void* a)
+{
+  try
+  {
+    unsigned int offset=(unsigned int)(unsigned long)a;
+    vector<DNSResourceRecord> entry;
+    for(unsigned int counter=0; counter < 100000; ++counter) {
+      DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset));
+      vector<uint8_t> query;
+      DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+      pwQ.getHeader()->rd = 1;
+
+      char responseBuf[4096];
+      uint16_t responseBufSize = sizeof(responseBuf);
+      uint32_t key = 0;
+      bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+      if (!found) {
+       g_missing++;
+      }
+    }
+  }
+  catch(PDNSException& e) {
+    cerr<<"Had error in threadReader: "<<e.reason<<endl;
+    throw;
+  }
+  return 0;
+}
+
+BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
+  try {
+    pthread_t tid[4];
+    for(int i=0; i < 4; ++i)
+      pthread_create(&tid[i], 0, threadMangler, (void*)(i*1000000UL));
+    void* res;
+    for(int i=0; i < 4 ; ++i)
+      pthread_join(tid[i], &res);
+
+    BOOST_CHECK_EQUAL(PC.getSize() + PC.getDeferredInserts() + PC.getInsertCollisions(), 400000);
+    BOOST_CHECK_SMALL(1.0*PC.getInsertCollisions(), 10000.0);
+
+    for(int i=0; i < 4; ++i)
+      pthread_create(&tid[i], 0, threadReader, (void*)(i*1000000UL));
+    for(int i=0; i < 4 ; ++i)
+      pthread_join(tid[i], &res);
+
+    BOOST_CHECK((PC.getDeferredInserts() + PC.getDeferredLookups() + PC.getInsertCollisions()) >= g_missing);
+  }
+  catch(PDNSException& e) {
+    cerr<<"Had error: "<<e.reason<<endl;
+    throw;
+  }
+
+}
+
+BOOST_AUTO_TEST_SUITE_END()
index e93ea7936f222aba73a6a728843a92e871ea5a00..392bbd395308577a4ae3040a81d86c0dff63cb56 100644 (file)
@@ -864,3 +864,325 @@ class TestAdvancedOr(DNSDistTest):
         (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
         receivedResponse.id = expectedResponse.id
         self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedCaching(DNSDistTest):
+
+    _config_template = """
+    pc = newPacketCache(5, 86400, 1)
+    getPool(""):setCache(pc)
+    addAction(makeRule("nocache.tests.powerdns.com."), SkipCacheAction())
+    newServer{address="127.0.0.1:%s"}
+    """
+    def testCached(self):
+        """
+        Advanced: Served from cache
+
+        dnsdist is configured to cache entries, we are sending several
+        identical requests and checking that the backend only receive
+        the first one.
+        """
+        numberOfQueries = 10
+        name = 'cached.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+
+        # first query to fill the cache
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, response)
+
+        for idx in range(numberOfQueries):
+            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+            receivedResponse.id = response.id
+            self.assertEquals(receivedResponse, response)
+
+            (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+            receivedResponse.id = response.id
+            self.assertEquals(receivedResponse, response)
+
+        total = 0
+        for key in TestAdvancedCaching._responsesCounter:
+            total += TestAdvancedCaching._responsesCounter[key]
+
+        self.assertEquals(total, 1)
+
+    def testSkipCache(self):
+        """
+        Advanced: SkipCacheAction
+
+        dnsdist is configured to not cache entries for nocache.tests.powerdns.com.
+         we are sending several requests and checking that the backend get them all.
+        """
+        name = 'nocache.tests.powerdns.com.'
+        numberOfQueries = 10
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+
+        for idx in range(numberOfQueries):
+            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            receivedResponse.id = response.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
+
+            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            receivedResponse.id = response.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
+
+        for key in TestAdvancedCaching._responsesCounter:
+            value = TestAdvancedCaching._responsesCounter[key]
+            self.assertEquals(value, numberOfQueries)
+
+    def testCacheExpiration(self):
+        """
+        Advanced: Cache expiration
+
+        dnsdist is configured to cache entries, we are sending one request
+        (cache miss) with a very short TTL, checking that the next requests
+        are cached. Then we wait for the TTL to expire, check that the
+        next request is a miss but the following one a hit.
+        """
+        ttl = 2
+        misses = 0
+        name = 'cacheexpiration.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+
+        # first query to fill the cache
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, response)
+        misses += 1
+
+        # next queries should hit the cache
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+
+        # now we wait a bit for the cache entry to expire
+        time.sleep(ttl + 1)
+
+        # next query should be a miss, fill the cache again
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, response)
+        misses += 1
+
+        # following queries should hit the cache again
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+
+        total = 0
+        for key in TestAdvancedCaching._responsesCounter:
+            total += TestAdvancedCaching._responsesCounter[key]
+
+        self.assertEquals(total, misses)
+
+    def testCacheDecreaseTTL(self):
+        """
+        Advanced: Cache decreases TTL
+
+        dnsdist is configured to cache entries, we are sending one request
+        (cache miss) and verify that the cache hits have a decreasing TTL.
+        """
+        ttl = 600
+        misses = 0
+        name = 'cachedecreasettl.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+
+        # first query to fill the cache
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, response)
+        misses += 1
+
+        # next queries should hit the cache
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+        for an in receivedResponse.answer:
+            self.assertTrue(an.ttl <= ttl)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+        for an in receivedResponse.answer:
+            self.assertTrue(an.ttl <= ttl)
+
+        # now we wait a bit for the TTL to decrease
+        time.sleep(1)
+
+        # next queries should hit the cache
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+        for an in receivedResponse.answer:
+            self.assertTrue(an.ttl < ttl)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        receivedResponse.id = response.id
+        self.assertEquals(receivedResponse, response)
+        for an in receivedResponse.answer:
+            self.assertTrue(an.ttl < ttl)
+
+        total = 0
+        for key in TestAdvancedCaching._responsesCounter:
+            total += TestAdvancedCaching._responsesCounter[key]
+
+        self.assertEquals(total, misses)
+
+    def testCacheDifferentCase(self):
+        """
+        Advanced: Cache matches different case
+
+        dnsdist is configured to cache entries, we are sending one request
+        (cache miss) and verify that the same one with a different case
+        matches.
+        """
+        ttl = 600
+        name = 'cachedifferentcase.tests.powerdns.com.'
+        differentCaseName = 'CacheDifferentCASE.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        differentCaseQuery = dns.message.make_query(differentCaseName, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        differentCaseResponse = dns.message.make_response(differentCaseQuery)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+        differentCaseResponse.answer.append(rrset)
+
+        # first query to fill the cache
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, response)
+
+        # different case query should still hit the cache
+        (_, receivedResponse) = self.sendUDPQuery(differentCaseQuery, response=None, useQueue=False)
+        receivedResponse.id = differentCaseResponse.id
+        self.assertEquals(receivedResponse, differentCaseResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(differentCaseQuery, response=None, useQueue=False)
+        receivedResponse.id = differentCaseResponse.id
+        self.assertEquals(receivedResponse, differentCaseResponse)
+
+class TestAdvancedCachingWithExistingEDNS(DNSDistTest):
+
+    _config_template = """
+    pc = newPacketCache(5, 86400, 1)
+    getPool(""):setCache(pc)
+    newServer{address="127.0.0.1:%s"}
+    """
+    def testCacheWithEDNS(self):
+        """
+        Advanced: Cache should not match different EDNS value
+
+        dnsdist is configured to cache entries, we are sending one request
+        (cache miss) and verify that the same one with a different EDNS UDP
+        Payload size is not served from the cache.
+        """
+        misses = 0
+        name = 'cachedifferentedns.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=512)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        misses += 1
+
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        misses += 1
+
+        total = 0
+        for key in TestAdvancedCaching._responsesCounter:
+            total += TestAdvancedCaching._responsesCounter[key]
+
+        self.assertEquals(total, misses)