* Delay a response by n milliseconds (DelayAction), over UDP only
* Modify query to clear the RD or CD bit
* Add the source MAC address to the query (MacAddrAction)
+ * Skip the cache, if any
Rules can be added via:
::/0
```
+Caching
+-------
+
+`dnsdist` implements a simple but effective packet cache, not enabled by default.
+It is enabled per-pool, but the same cache can be shared between several pools.
+The first step is to define a cache, then to assign that cache to the chosen pool,
+the default one being represented by the empty string:
+
+```
+pc = newPacketCache(10000, 86400, 600)
+getPool(""):setCache(pc)
+```
+
+The first parameter is the maximum number of entries stored in the cache. The
+second, optional, is the maximum lifetime of an entry in the cache, in seconds
+(default 86400). The third, also optional, is the minimum TTL a response must
+have to be inserted into the cache (default 60).
+
+
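The TTL handling described above can be sketched in C++ as follows. This is a minimal illustration, not the dnsdist implementation: `cacheableTTL` is a hypothetical helper name, and it assumes the lowest record TTL of the response has already been extracted from the packet.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper (not part of dnsdist): decide how long a response may
// stay in the cache, given the lowest TTL found among its records.
// Returns false when the response should not be cached at all; otherwise
// stores the lifetime to use in *ttlOut and returns true.
bool cacheableTTL(uint32_t lowestRecordTTL, uint32_t maxTTL, uint32_t minTTL, uint32_t* ttlOut)
{
  // Cap the lifetime at maxTTL (second parameter of newPacketCache).
  const uint32_t ttl = std::min(lowestRecordTTL, maxTTL);
  // Refuse responses whose TTL is below minTTL (third parameter).
  if (ttl < minTTL) {
    return false;
  }
  *ttlOut = ttl;
  return true;
}
```

With the values from the configuration example (`maxTTL` of 86400 and `minTTL` of 600), a response whose lowest TTL is 30 seconds would be skipped, while one with a TTL of a week would be kept for one day at most.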
Carbon/Graphite/Metronome
-------------------------
To emit metrics to Graphite, or any other software supporting the Carbon protocol, use:
* `QPSPoolAction()`: set the packet into the specified pool only if it does not exceed the specified QPS limits
* `QPSAction()`: drop these packets if the QPS limits are exceeded
* `RCodeAction()`: reply immediately by turning the query into a response with the specified rcode
+ * `SkipCacheAction()`: don't lookup the cache for this query, don't store the answer
* `SpoofAction()`: forge a response with the specified IPv4 (for an A query) or IPv6 (for an AAAA). If you specify two addresses, the first one should be an IPv4 and will be used for A, the second an IPv6 for an AAAA
* `SpoofCNAMEAction()`: forge a response with the specified CNAME value
* `TCAction()`: create answer to query with TC and RD bits set, to move to TCP/IP
* `addPoolRule(netmask, pool)`: send queries to this netmask to that pool
* `addPoolRule({netmask, netmask}, pool)`: send queries to these netmasks to that pool
* `addQPSPoolRule(x, limit, pool)`: like `addPoolRule`, but only select at most 'limit' queries/s for this pool
+ * `getPool(poolname)`: return the ServerPool named `poolname`
* `getPoolServers(pool)`: return servers part of this pool
+ * `showPools()`: list the current server pools
* Lua Action related:
* `addLuaAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a
function with parameters remote, qname, qtype, dh and len, which returns an action to be taken
* `exceedRespByterate(rate, seconds)`: get set of addresses that exceeded `rate` bytes/s answers over `seconds` seconds
* `exceedQRate(rate, seconds)`: get set of address that exceed `rate` queries/s over `seconds` seconds
* `exceedQTypeRate(type, rate, seconds)`: get set of address that exceed `rate` queries/s for queries of type `type` over `seconds` seconds
+ * ServerPool related:
+ * `getCache()`: return the current packet cache, if any
+ * `setCache(PacketCache)`: set the cache for this pool
+ * PacketCache related:
+ * `expungeByName(DNSName [, qtype=ANY])`: remove entries matching the supplied DNSName and, optionally, query type from the cache
+ * `isFull()`: return true if the cache has reached the maximum number of entries
+ * `newPacketCache(maxEntries, maxTTL=86400, minTTL=60)`: return a new PacketCache
+ * `printStats()`: print the cache stats (hits, misses, deferred lookups, deferred inserts, lookup collisions and insert collisions)
+ * `purge()`: remove entries from the cache until the number of entries is lower than the maximum, starting with expired ones.
+ * `toString()`: return the number of entries in the Packet Cache, and the maximum number of entries
* Advanced functions for writing your own policies and hooks
* ComboAddress related:
* `newCA(address)`: return a new ComboAddress
* `setTCPSendTimeout(n)`: set the write timeout on TCP connections from the client, in seconds
* `setMaxTCPClientThreads(n)`: set the maximum of TCP client threads, handling TCP connections
* `setMaxUDPOutstanding(n)`: set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time
+ * `setCacheCleaningDelay(n)`: set the interval in seconds between two runs of the cache cleaning algorithm, removing expired entries
* DNSCrypt related:
* `addDNSCryptBind("127.0.0.1:8443", "provider name", "/path/to/resolver.cert", "/path/to/resolver.key"):` listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of "provider name", using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files
* `generateDNSCryptProviderKeys("/path/to/providerPublic.key", "/path/to/providerPrivate.key"):` generate a new provider keypair
--- /dev/null
+#include "dolog.hh"
+#include "dnsdist-cache.hh"
+#include "dnsparser.hh"
+
+DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL): d_maxEntries(maxEntries), d_maxTTL(maxTTL), d_minTTL(minTTL)
+{
+ pthread_rwlock_init(&d_lock, 0);
+ /* we reserve maxEntries + 1 to avoid rehashing
+ when we reach maxEntries, as that means a load factor of 1 */
+ d_map.reserve(maxEntries + 1);
+}
+
+DNSDistPacketCache::~DNSDistPacketCache()
+{
+ WriteLock l(&d_lock);
+}
+
+bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass)
+{
+ if (cachedValue.qname != qname || cachedValue.qtype != qtype || cachedValue.qclass != qclass)
+ return false;
+ return true;
+}
+
+void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen)
+{
+ if (responseLen == 0)
+ return;
+
+ uint32_t minTTL = getMinTTL(response, responseLen);
+ if (minTTL > d_maxTTL)
+ minTTL = d_maxTTL;
+
+ if (minTTL < d_minTTL)
+ return;
+
+ {
+ TryReadLock r(&d_lock);
+ if (!r.gotIt()) {
+ d_deferredInserts++;
+ return;
+ }
+ if (d_map.size() >= d_maxEntries) {
+ return;
+ }
+ }
+
+ const time_t now = time(NULL);
+ std::unordered_map<uint32_t,CacheValue>::iterator it;
+ bool result;
+ time_t newValidity = now + minTTL;
+ CacheValue newValue;
+ newValue.qname = qname;
+ newValue.qtype = qtype;
+ newValue.qclass = qclass;
+ newValue.len = responseLen;
+ newValue.validity = newValidity;
+ newValue.added = now;
+ newValue.value = std::string(response, responseLen);
+
+ {
+ TryWriteLock w(&d_lock);
+
+ if (!w.gotIt()) {
+ d_deferredInserts++;
+ return;
+ }
+
+ tie(it, result) = d_map.insert({key, newValue});
+
+ if (result) {
+ return;
+ }
+
+ /* in case of collision, don't override the existing entry
+ except if it has expired */
+ CacheValue& value = it->second;
+ bool wasExpired = value.validity <= now;
+
+ if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass)) {
+ d_insertCollisions++;
+ return;
+ }
+
+ /* if the existing entry had a longer TTD, keep it */
+ if (newValidity <= value.validity)
+ return;
+
+ value = newValue;
+ }
+}
+
+bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging)
+{
+ uint32_t key = getKey(qname, consumed, query, queryLen);
+ if (keyOut)
+ *keyOut = key;
+
+ time_t now = time(NULL);
+ time_t age;
+ {
+ TryReadLock r(&d_lock);
+ if (!r.gotIt()) {
+ d_deferredLookups++;
+ return false;
+ }
+
+ std::unordered_map<uint32_t,CacheValue>::const_iterator it = d_map.find(key);
+ if (it == d_map.end()) {
+ d_misses++;
+ return false;
+ }
+
+ const CacheValue& value = it->second;
+ if (value.validity < now) {
+ d_misses++;
+ return false;
+ }
+
+ if (*responseLen < value.len) {
+ return false;
+ }
+
+ /* check for collision */
+ if (!cachedValueMatches(value, qname, qtype, qclass)) {
+ d_misses++;
+ d_lookupCollisions++;
+ return false;
+ }
+
+ string dnsQName(qname.toDNSString());
+ memcpy(response, &queryId, sizeof(queryId));
+ memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
+ memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQName.length());
+ memcpy(response + sizeof(dnsheader) + dnsQName.length(), value.value.c_str() + sizeof(dnsheader) + dnsQName.length(), value.value.length() - (sizeof(dnsheader) + dnsQName.length()));
+ *responseLen = value.len;
+ age = now - value.added;
+ }
+
+ if (!skipAging)
+ ageDNSPacket(response, *responseLen, age);
+ d_hits++;
+ return true;
+}
+
+void DNSDistPacketCache::purge(size_t upTo)
+{
+ time_t now = time(NULL);
+ WriteLock w(&d_lock);
+ /* nothing to do if the cache already holds no more than upTo entries */
+ if (upTo >= d_map.size())
+ return;
+
+ size_t toRemove = d_map.size() - upTo;
+ for(auto it = d_map.begin(); toRemove > 0 && it != d_map.end(); ) {
+ const CacheValue& value = it->second;
+
+ if (value.validity < now) {
+ it = d_map.erase(it);
+ --toRemove;
+ } else {
+ ++it;
+ }
+ }
+}
+
+void DNSDistPacketCache::expunge(const DNSName& name, uint16_t qtype)
+{
+ WriteLock w(&d_lock);
+
+ for(auto it = d_map.begin(); it != d_map.end(); ) {
+ const CacheValue& value = it->second;
+ uint16_t cqtype = 0;
+ uint16_t cqclass = 0;
+ DNSName cqname(value.value.c_str(), value.len, sizeof(dnsheader), false, &cqtype, &cqclass, nullptr);
+
+ if (cqname == name && (qtype == QType::ANY || qtype == cqtype)) {
+ it = d_map.erase(it);
+ } else {
+ ++it;
+ }
+ }
+}
+
+bool DNSDistPacketCache::isFull()
+{
+ ReadLock r(&d_lock);
+ return (d_map.size() >= d_maxEntries);
+}
+
+uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length)
+{
+ const struct dnsheader* dh = (const struct dnsheader*) packet;
+ uint32_t result = std::numeric_limits<uint32_t>::max();
+ vector<uint8_t> content(length - sizeof(dnsheader));
+ copy(packet + sizeof(dnsheader), packet + length, content.begin());
+ PacketReader pr(content);
+ size_t idx = 0;
+ DNSName rrname;
+ uint16_t qdcount = ntohs(dh->qdcount);
+ uint16_t ancount = ntohs(dh->ancount);
+ uint16_t nscount = ntohs(dh->nscount);
+ uint16_t arcount = ntohs(dh->arcount);
+ uint16_t rrtype;
+ uint16_t rrclass;
+ struct dnsrecordheader ah;
+
+ /* consume qd */
+ for(idx = 0; idx < qdcount; idx++) {
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+ (void) rrtype;
+ (void) rrclass;
+ }
+
+ /* consume AN and NS */
+ for (idx = 0; idx < ancount + nscount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+ pr.d_pos += ah.d_clen;
+ if (result > ah.d_ttl)
+ result = ah.d_ttl;
+ }
+
+ /* consume AR, watch for OPT */
+ for (idx = 0; idx < arcount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+ pr.d_pos += ah.d_clen;
+ if (ah.d_type == QType::OPT) {
+ continue;
+ }
+ if (result > ah.d_ttl)
+ result = ah.d_ttl;
+ }
+ return result;
+}
+
+uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen)
+{
+ uint32_t result = 0;
+ /* skip the query ID */
+ result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
+ string lc(qname.toDNSStringLC());
+ result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
+ result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
+ return result;
+}
+
+string DNSDistPacketCache::toString()
+{
+ ReadLock r(&d_lock);
+ return std::to_string(d_map.size()) + "/" + std::to_string(d_maxEntries);
+}
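The expiry-driven cleanup performed by `purge()` can be illustrated in isolation. The sketch below is a simplified stand-in, not the code from this change: `ExpiryMap` and `purgeExpired` are hypothetical names, the map stores only an expiry time per key, no locking is done, and it assumes the intent is to stop early when the map already holds `upTo` entries or fewer.

```cpp
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <unordered_map>

// Simplified stand-in for the cache map: key -> expiry time ("validity").
using ExpiryMap = std::unordered_map<uint32_t, time_t>;

// Remove expired entries until at most `upTo` entries remain or the whole
// map has been scanned. Returns the number of entries removed.
size_t purgeExpired(ExpiryMap& entries, size_t upTo, time_t now)
{
  if (entries.size() <= upTo) {
    return 0; // already small enough, nothing to do
  }
  size_t toRemove = entries.size() - upTo;
  size_t removed = 0;
  for (auto it = entries.begin(); toRemove > 0 && it != entries.end(); ) {
    if (it->second < now) {
      it = entries.erase(it); // erase() returns the next valid iterator
      --toRemove;
      ++removed;
    } else {
      ++it;
    }
  }
  return removed;
}
```

Entries that have not expired are never removed here, so the map can stay above `upTo` when too few entries are stale.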
--- /dev/null
+#pragma once
+
+#include <atomic>
+#include <unordered_map>
+#include "lock.hh"
+
+class DNSDistPacketCache : boost::noncopyable
+{
+public:
+ DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=60);
+ ~DNSDistPacketCache();
+
+ void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen);
+ bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging=false);
+ void purge(size_t upTo=0);
+ void expunge(const DNSName& name, uint16_t qtype=QType::ANY);
+ bool isFull();
+ string toString();
+ uint64_t getSize() const { return d_map.size(); };
+ uint64_t getHits() const { return d_hits; };
+ uint64_t getMisses() const { return d_misses; };
+ uint64_t getDeferredLookups() const { return d_deferredLookups; };
+ uint64_t getDeferredInserts() const { return d_deferredInserts; };
+ uint64_t getLookupCollisions() const { return d_lookupCollisions; };
+ uint64_t getInsertCollisions() const { return d_insertCollisions; };
+
+ static uint32_t getMinTTL(const char* packet, uint16_t length);
+
+private:
+
+ struct CacheValue
+ {
+ time_t getTTD() const { return validity; }
+ std::string value;
+ DNSName qname;
+ uint16_t qtype{0};
+ uint16_t qclass{0};
+ time_t added{0};
+ time_t validity{0};
+ uint16_t len{0};
+ };
+
+ static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen);
+ static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass);
+
+ pthread_rwlock_t d_lock;
+ std::unordered_map<uint32_t,CacheValue> d_map;
+ std::atomic<uint64_t> d_deferredLookups{0};
+ std::atomic<uint64_t> d_deferredInserts{0};
+ std::atomic<uint64_t> d_hits{0};
+ std::atomic<uint64_t> d_misses{0};
+ std::atomic<uint64_t> d_insertCollisions{0};
+ std::atomic<uint64_t> d_lookupCollisions{0};
+ size_t d_maxEntries;
+ uint32_t d_maxTTL;
+ uint32_t d_minTTL;
+};
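Because the map is keyed on a 32-bit hash, two different questions can land on the same key, which is why each `CacheValue` also stores the qname, qtype and qclass. The following sketch shows that lookup-time check on its own. It is an assumption-laden illustration: `Entry` and `TinyCache` are hypothetical types, names are plain strings rather than `DNSName`, and locking, expiry and response rewriting are left out.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical, simplified cache entry: the question tuple plus the payload.
struct Entry {
  std::string qname;
  uint16_t qtype;
  uint16_t qclass;
  std::string response;
};

struct TinyCache {
  std::unordered_map<uint32_t, Entry> map;
  uint64_t hits = 0;
  uint64_t misses = 0;
  uint64_t lookupCollisions = 0;

  // Return the entry stored under `key` only if it was stored for the same
  // question; a mismatch is counted as both a miss and a collision.
  const Entry* lookup(uint32_t key, const std::string& qname, uint16_t qtype, uint16_t qclass)
  {
    auto it = map.find(key);
    if (it == map.end()) {
      ++misses;
      return nullptr;
    }
    const Entry& e = it->second;
    if (e.qname != qname || e.qtype != qtype || e.qclass != qclass) {
      ++misses;
      ++lookupCollisions;
      return nullptr;
    }
    ++hits;
    return &e;
  }
};
```

Treating a collision as a miss keeps the cache correct at the cost of one extra backend query; the separate counter makes it possible to see how often it happens.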
ret->qps=QPSLimiter(qps, qps);
}
+ auto localPools = g_pools.getCopy();
if(vars.count("pool")) {
if(auto* pool = boost::get<string>(&vars["pool"]))
ret->pools.insert(*pool);
for(auto& p : *pools)
ret->pools.insert(p.second);
}
+ for(const auto& poolName: ret->pools) {
+ addServerToPool(localPools, poolName, ret);
+ }
+ }
+ else {
+ addServerToPool(localPools, "", ret);
}
+ g_pools.setState(localPools);
if(vars.count("order")) {
ret->order=std::stoi(boost::get<string>(vars["order"]));
[](boost::variant<std::shared_ptr<DownstreamState>, int> var)
{
setLuaSideEffect();
- auto states = g_dstates.getCopy();
- if(auto* rem = boost::get<shared_ptr<DownstreamState>>(&var))
- states.erase(remove(states.begin(), states.end(), *rem), states.end());
- else
- states.erase(states.begin() + boost::get<int>(var));
- g_dstates.setState(states);
+ shared_ptr<DownstreamState> server;
+ auto* rem = boost::get<shared_ptr<DownstreamState>>(&var);
+ auto states = g_dstates.getCopy();
+ if(rem) {
+ server = *rem;
+ }
+ else {
+ int idx = boost::get<int>(var);
+ server = states[idx];
+ }
+ auto localPools = g_pools.getCopy();
+ for (const string& poolName : server->pools) {
+ removeServerFromPool(localPools, poolName, server);
+ }
+ g_pools.setState(localPools);
+ states.erase(remove(states.begin(), states.end(), server), states.end());
+ g_dstates.setState(states);
} );
return std::shared_ptr<DNSAction>(new DelayAction(msec));
});
-
g_lua.writeFunction("TCAction", []() {
return std::shared_ptr<DNSAction>(new TCAction);
});
return std::shared_ptr<DNSAction>(new RCodeAction(rcode));
});
+ g_lua.writeFunction("SkipCacheAction", []() {
+ return std::shared_ptr<DNSAction>(new SkipCacheAction);
+ });
+
g_lua.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional<int> ipv4trunc, boost::optional<int> ipv6trunc) {
return std::shared_ptr<DNSRule>(new MaxQPSIPRule(qps, ipv4trunc.get_value_or(32), ipv6trunc.get_value_or(64)));
});
});
g_lua.writeFunction("getPoolServers", [](string pool) {
- return getDownstreamCandidates(g_dstates.getCopy(), pool);
+ return getDownstreamCandidates(g_pools.getCopy(), pool);
});
g_lua.writeFunction("getServer", [client](int i) {
});
g_lua.registerFunction<void(DownstreamState::*)(int)>("setQPS", [](DownstreamState& s, int lim) { s.qps = lim ? QPSLimiter(lim, lim) : QPSLimiter(); });
- g_lua.registerFunction<void(DownstreamState::*)(string)>("addPool", [](DownstreamState& s, string pool) { s.pools.insert(pool);});
- g_lua.registerFunction<void(DownstreamState::*)(string)>("rmPool", [](DownstreamState& s, string pool) { s.pools.erase(pool);});
+ g_lua.registerFunction<void(std::shared_ptr<DownstreamState>::*)(string)>("addPool", [](std::shared_ptr<DownstreamState> s, string pool) {
+ auto localPools = g_pools.getCopy();
+ addServerToPool(localPools, pool, s);
+ g_pools.setState(localPools);
+ s->pools.insert(pool);
+ });
+ g_lua.registerFunction<void(std::shared_ptr<DownstreamState>::*)(string)>("rmPool", [](std::shared_ptr<DownstreamState> s, string pool) {
+ auto localPools = g_pools.getCopy();
+ removeServerFromPool(localPools, pool, s);
+ g_pools.setState(localPools);
+ s->pools.erase(pool);
+ });
g_lua.registerFunction<void(DownstreamState::*)()>("getOutstanding", [](const DownstreamState& s) { g_outputBuffer=std::to_string(s.outstanding.load()); });
g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { g_maxTCPClientThreads = max; });
+ g_lua.writeFunction("setCacheCleaningDelay", [](uint32_t delay) { g_cacheCleaningDelay = delay; });
+
g_lua.writeFunction("setECSSourcePrefixV4", [](uint16_t prefix) { g_ECSSourcePrefixV4=prefix; });
g_lua.writeFunction("setECSSourcePrefixV6", [](uint16_t prefix) { g_ECSSourcePrefixV6=prefix; });
}
});
- moreLua();
+ moreLua(client);
std::ifstream ifs(config);
if(!ifs)
#include "dnsdist.hh"
+#include "dnsdist-cache.hh"
#include "dnsrulactions.hh"
#include <thread>
#include "dolog.hh"
}
-void moreLua()
+void moreLua(bool client)
{
typedef NetmaskTree<DynBlock> nmts_t;
g_lua.writeFunction("newCA", [](const std::string& name) { return ComboAddress(name); });
#endif
});
+ g_lua.writeFunction("showPools", []() {
+ setLuaNoSideEffect();
+ try {
+ ostringstream ret;
+ boost::format fmt("%1$-20.20s %|25t|%2$20s %|50t|%3%" );
+ // 1 2 3
+ ret << (fmt % "Name" % "Cache" % "Servers" ) << endl;
+
+ const auto localPools = g_pools.getCopy();
+ for (const auto& entry : localPools) {
+ const string& name = entry.first;
+ const std::shared_ptr<ServerPool> pool = entry.second;
+ string cache = pool->packetCache != nullptr ? pool->packetCache->toString() : "";
+ string servers;
+
+ for (const auto& server: pool->servers) {
+ if (!servers.empty()) {
+ servers += ", ";
+ }
+ if (!server.second->name.empty()) {
+ servers += server.second->name;
+ servers += " ";
+ }
+ servers += server.second->remote.toStringWithPort();
+ }
+
+ ret << (fmt % name % cache % servers) << endl;
+ }
+ g_outputBuffer=ret.str();
+ }catch(std::exception& e) { g_outputBuffer=e.what(); throw; }
+ });
+
+ g_lua.registerFunction<void(std::shared_ptr<ServerPool>::*)(std::shared_ptr<DNSDistPacketCache>)>("setCache", [](std::shared_ptr<ServerPool> pool, std::shared_ptr<DNSDistPacketCache> cache) {
+ pool->packetCache = cache;
+ });
+ g_lua.registerFunction("getCache", &ServerPool::getCache);
+
+ g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional<uint32_t> maxTTL, boost::optional<uint32_t> minTTL) {
+ return std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 60);
+ });
+ g_lua.registerFunction("toString", &DNSDistPacketCache::toString);
+ g_lua.registerFunction("isFull", &DNSDistPacketCache::isFull);
+ g_lua.registerFunction("purge", &DNSDistPacketCache::purge);
+ g_lua.registerFunction<void(std::shared_ptr<DNSDistPacketCache>::*)(const DNSName& dname, boost::optional<uint16_t> qtype)>("expungeByName", [](std::shared_ptr<DNSDistPacketCache> cache, const DNSName& dname, boost::optional<uint16_t> qtype) {
+ cache->expunge(dname, qtype ? *qtype : QType::ANY);
+ });
+ g_lua.registerFunction<void(std::shared_ptr<DNSDistPacketCache>::*)()>("printStats", [](const std::shared_ptr<DNSDistPacketCache> cache) {
+ g_outputBuffer="Hits: " + std::to_string(cache->getHits()) + "\n";
+ g_outputBuffer+="Misses: " + std::to_string(cache->getMisses()) + "\n";
+ g_outputBuffer+="Deferred inserts: " + std::to_string(cache->getDeferredInserts()) + "\n";
+ g_outputBuffer+="Deferred lookups: " + std::to_string(cache->getDeferredLookups()) + "\n";
+ g_outputBuffer+="Lookup Collisions: " + std::to_string(cache->getLookupCollisions()) + "\n";
+ g_outputBuffer+="Insert Collisions: " + std::to_string(cache->getInsertCollisions()) + "\n";
+ });
+
+ g_lua.writeFunction("getPool", [client](const string& poolName) {
+ if (client) {
+ return std::make_shared<ServerPool>();
+ }
+ auto localPools = g_pools.getCopy();
+ std::shared_ptr<ServerPool> pool = createPoolIfNotExists(localPools, poolName);
+ g_pools.setState(localPools);
+ return pool;
+ });
}
auto localPolicy = g_policy.getLocal();
auto localRulactions = g_rulactions.getLocal();
auto localDynBlockNMG = g_dynblockNMG.getLocal();
+ auto localPools = g_pools.getLocal();
map<ComboAddress,int> sockets;
for(;;) {
delete citmp;
uint16_t qlen, rlen;
- string pool;
+ string poolname;
const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
break;
/* non-terminal actions follow */
case DNSAction::Action::Pool:
- pool=ruleresult;
+ poolname=ruleresult;
break;
case DNSAction::Action::Delay:
case DNSAction::Action::None:
if(dq.qtype == QType::AXFR || dq.qtype == QType::IXFR) // XXX fixme we really need to do better
break;
+ std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
{
std::lock_guard<std::mutex> lock(g_luamutex);
- ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), &dq);
+ ds = localPolicy->policy(serverPool->servers, &dq);
}
if(!ds) {
g_stats.noPolicy++;
}
}
+ uint32_t cacheKey = 0;
+ if (serverPool->packetCache && !dq.skipCache) {
+ char cachedResponse[4096];
+ uint16_t cachedResponseSize = sizeof cachedResponse;
+ if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+ if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
+ writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
+ g_stats.cacheHits++;
+ goto drop;
+ }
+ g_stats.cacheMisses++;
+ }
+
int dsock;
if(sockets.count(ds->remote) == 0) {
dsock=sockets[ds->remote]=setupTCPDownstream(ds);
memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
}
+ if (serverPool->packetCache && !dq.skipCache) {
+ serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen);
+ }
+
#ifdef HAVE_DNSCRYPT
if (ci.cs->dnscryptCtx) {
uint16_t encryptedResponseLen = 0;
#include <pwd.h>
#include "lock.hh"
#include <getopt.h>
+#include "dnsdist-cache.hh"
/* Known sins:
std::vector<std::pair<ComboAddress,DnsCryptContext>> g_dnsCryptLocals;
#endif
vector<ClientState *> g_frontends;
+GlobalStateHolder<pools_t> g_pools;
/* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
Then we have a bunch of connected sockets for talking to downstream servers.
g_stats.responses++;
+ if (ids->packetCache && !ids->skipCache) {
+ ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen);
+ }
+
#ifdef HAVE_DNSCRYPT
uint16_t encryptedResponseLen = 0;
if(ids->dnsCryptQuery) {
// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
{
+ if (servers.size() == 1 && servers[0].second->isUp()) {
+ return servers[0].second;
+ }
+
vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
/* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
ComboAddress g_serverControl{"127.0.0.1:5199"};
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
+{
+ std::shared_ptr<ServerPool> pool;
+ pools_t::iterator it = pools.find(poolName);
+ if (it != pools.end()) {
+ pool = it->second;
+ }
+ else {
+ vinfolog("Creating pool %s", poolName);
+ pool = std::make_shared<ServerPool>();
+ pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
+ }
+ return pool;
+}
-NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool)
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
{
- NumberedServerVector ret;
- int count=0;
- for(const auto& s : servers)
- if((pool.empty() && s->pools.empty()) || s->pools.count(pool))
- ret.push_back(make_pair(++count, s));
-
- return ret;
+ std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+ unsigned int count = pool->servers.size();
+ vinfolog("Adding server to pool %s", poolName);
+ pool->servers.push_back(make_pair(++count, server));
+}
+
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+ vinfolog("Removing from pool %s", poolName);
+ pools_t::iterator poolIt = pools.find(poolName);
+ if (poolIt == pools.end()) {
+ throw std::out_of_range("No pool named " + poolName);
+ }
+
+ std::shared_ptr<ServerPool> pool = poolIt->second;
+
+ for (NumberedVector<shared_ptr<DownstreamState> >::iterator it = pool->servers.begin(); it != pool->servers.end(); it++) {
+ if (it->second == server) {
+ pool->servers.erase(it);
+ break;
+ }
+ }
+}
+
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
+{
+ pools_t::const_iterator it = pools.find(poolName);
+
+ if (it == pools.end()) {
+ throw std::out_of_range("No pool named " + poolName);
+ }
+
+ return it->second;
+}
+
+const NumberedServerVector& getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
+{
+ std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+ return pool->servers;
}
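The pool bookkeeping introduced by the helpers above can be exercised outside dnsdist with a small model. This is only a sketch under stated assumptions: `Server`, `Pool`, `Pools`, `createIfMissing`, `addServer` and `removeServer` are hypothetical stand-ins for `DownstreamState`, `ServerPool`, `pools_t` and the functions in this change, and logging and the copy-on-write `GlobalStateHolder` are omitted.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct Server { std::string name; };

struct Pool {
  // (position, server) pairs, modelled on NumberedVector
  std::vector<std::pair<unsigned int, std::shared_ptr<Server>>> servers;
};

using Pools = std::map<std::string, std::shared_ptr<Pool>>;

// Return the named pool, creating an empty one on first use.
std::shared_ptr<Pool> createIfMissing(Pools& pools, const std::string& name)
{
  auto it = pools.find(name);
  if (it != pools.end()) {
    return it->second;
  }
  auto pool = std::make_shared<Pool>();
  pools.emplace(name, pool);
  return pool;
}

// Append the server to the pool, numbering it after the current last entry.
void addServer(Pools& pools, const std::string& name, std::shared_ptr<Server> server)
{
  auto pool = createIfMissing(pools, name);
  unsigned int count = static_cast<unsigned int>(pool->servers.size());
  pool->servers.push_back(std::make_pair(count + 1, server));
}

// Remove the first occurrence of the server; unknown pools are an error.
void removeServer(Pools& pools, const std::string& name, const std::shared_ptr<Server>& server)
{
  auto it = pools.find(name);
  if (it == pools.end()) {
    throw std::out_of_range("No pool named " + name);
  }
  auto& servers = it->second->servers;
  for (auto s = servers.begin(); s != servers.end(); ++s) {
    if (s->second == server) {
      servers.erase(s);
      break;
    }
  }
}
```

As in the change itself, removing a server does not renumber the remaining entries, so positions may have gaps after a removal.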
// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
auto localRulactions = g_rulactions.getLocal();
auto localServers = g_dstates.getLocal();
auto localDynBlock = g_dynblockNMG.getLocal();
+ auto localPools = g_pools.getLocal();
struct msghdr msgh;
struct iovec iov;
/* used by HarvestDestinationAddress */
DNSAction::Action action=DNSAction::Action::None;
string ruleresult;
- string pool;
+ string poolname;
int delayMsec=0;
bool done=false;
for(const auto& lr : *localRulactions) {
break;
/* non-terminal actions follow */
case DNSAction::Action::Pool:
- pool=ruleresult;
+ poolname=ruleresult;
break;
case DNSAction::Action::Delay:
delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
continue;
}
- DownstreamState* ss = 0;
- auto candidates=getDownstreamCandidates(*localServers, pool);
+ DownstreamState* ss = nullptr;
+ std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
auto policy=localPolicy->policy;
{
std::lock_guard<std::mutex> lock(g_luamutex);
- ss = policy(candidates, &dq).get();
+ ss = policy(serverPool->servers, &dq).get();
}
if(!ss) {
continue;
}
+ bool ednsAdded = false;
+ if (ss->useECS) {
+ handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), remote);
+ }
+
+ uint32_t cacheKey = 0;
+ if (serverPool->packetCache && !dq.skipCache) {
+ char cachedResponse[4096];
+ uint16_t cachedResponseSize = sizeof cachedResponse;
+ if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+ ComboAddress dest;
+ if(HarvestDestinationAddress(&msgh, &dest))
+ sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
+ else
+ sendto(cs->udpFD, cachedResponse, cachedResponseSize, 0, (struct sockaddr*)&remote, remote.getSocklen());
+ g_stats.cacheHits++;
+ continue;
+ }
+ g_stats.cacheMisses++;
+ }
+
ss->queries++;
unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
ids->sentTime.start();
ids->qname = qname;
ids->qtype = dq.qtype;
+ ids->qclass = dq.qclass;
ids->origDest.sin4.sin_family=0;
ids->delayMsec = delayMsec;
ids->origFlags = origFlags;
ids->ednsAdded = false;
+ ids->cacheKey = cacheKey;
+ ids->skipCache = dq.skipCache;
+ ids->packetCache = serverPool->packetCache;
+ ids->ednsAdded = ednsAdded;
#ifdef HAVE_DNSCRYPT
ids->dnsCryptQuery = dnsCryptQuery;
#endif
dh->id = idOffset;
- if (ss->useECS) {
- handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ids->ednsAdded), remote);
- }
-
if (largerQuery.empty()) {
ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len);
}
}
std::atomic<uint64_t> g_maxTCPClientThreads{10};
+std::atomic<uint16_t> g_cacheCleaningDelay{60};
void* maintThread()
+{
+ int interval = 1;
+ size_t counter = 0;
+
+ for(;;) {
+ sleep(interval);
+
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ auto f =g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
+ if(f)
+ (*f)();
+
+ counter++;
+ if (counter >= g_cacheCleaningDelay) {
+ const auto localPools = g_pools.getCopy();
+ for (const auto& entry : localPools) {
+ if (entry.second->packetCache) {
+ entry.second->packetCache->purge();
+ }
+ }
+ counter = 0;
+ }
+
+ // ponder pruning g_dynblocks of expired entries here
+ }
+ return 0;
+}
+
+void* healthChecksThread()
{
int interval = 1;
}
}
}
-
- std::lock_guard<std::mutex> lock(g_luamutex);
- auto f =g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
- if(f)
- (*f)();
-
-
- // ponder pruning g_dynblocks of expired entries here
}
return 0;
}
for(auto& t : todo)
t();
-
+ auto localPools = g_pools.getCopy();
if(g_cmdLine.remotes.size()) {
for(const auto& address : g_cmdLine.remotes) {
auto ret=std::make_shared<DownstreamState>(ComboAddress(address, 53));
+ addServerToPool(localPools, "", ret);
ret->tid = move(thread(responderThread, ret));
g_dstates.modify([ret](servers_t& servers) { servers.push_back(ret); });
}
}
+ g_pools.setState(localPools);
if(g_dstates.getCopy().empty()) {
errlog("No downstream servers defined: all packets will get dropped");
carbonthread.detach();
thread stattid(maintThread);
+ stattid.detach();
+ thread healththread(healthChecksThread);
+
if(g_cmdLine.beDaemon || g_cmdLine.beSupervised) {
- stattid.join();
+ healththread.join();
}
else {
- stattid.detach();
+ healththread.detach();
doConsole();
}
_exit(EXIT_SUCCESS);
#include <thread>
#include "sholder.hh"
#include "dnscrypt.hh"
+#include "dnsdist-cache.hh"
void* carbonDumpThread();
uint64_t uptimeOfProcess(const std::string& str);
stat_t downstreamSendErrors{0};
stat_t truncFail{0};
stat_t noPolicy{0};
+ stat_t cacheHits{0};
+ stat_t cacheMisses{0};
stat_t latency0_1{0}, latency1_10{0}, latency10_50{0}, latency50_100{0}, latency100_1000{0}, latencySlow{0};
double latencyAvg100{0}, latencyAvg1000{0}, latencyAvg10000{0}, latencyAvg1000000{0};
{"noncompliant-queries", &nonCompliantQueries},
{"rdqueries", &rdQueries},
{"empty-queries", &emptyQueries},
+ {"cache-hits", &cacheHits},
+ {"cache-misses", &cacheMisses},
{"cpu-user-msec", getCPUTimeUser},
{"cpu-sys-msec", getCPUTimeSystem},
{"fd-usage", getOpenFileDescriptors}, {"dyn-blocked", &dynBlocked},
#ifdef HAVE_DNSCRYPT
std::shared_ptr<DnsCryptQuery> dnsCryptQuery{0};
#endif
+ std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+ uint32_t cacheKey; // 8
std::atomic<uint16_t> age; // 4
uint16_t qtype; // 2
+ uint16_t qclass; // 2
uint16_t origID; // 2
uint16_t origFlags; // 2
int delayMsec;
bool ednsAdded{false};
+ bool skipCache{false};
};
struct Rings {
size_t size;
uint16_t len;
const bool tcp;
+ bool skipCache{false};
};
template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
policyfunc_t policy;
};
+struct ServerPool
+{
+ const std::shared_ptr<DNSDistPacketCache> getCache() const { return packetCache; };
+
+ NumberedVector<shared_ptr<DownstreamState>> servers;
+ std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+};
+using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+
struct CarbonConfig
{
ComboAddress server{"0.0.0.0", 0};
extern GlobalStateHolder<CarbonConfig> g_carbon;
extern GlobalStateHolder<ServerPolicy> g_policy;
extern GlobalStateHolder<servers_t> g_dstates;
+extern GlobalStateHolder<pools_t> g_pools;
extern GlobalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > > g_rulactions;
extern GlobalStateHolder<NetmaskGroup> g_ACL;
extern uint16_t g_maxOutstanding;
extern std::atomic<bool> g_configurationDone;
extern std::atomic<uint64_t> g_maxTCPClientThreads;
+extern std::atomic<uint16_t> g_cacheCleaningDelay;
extern uint16_t g_ECSSourcePrefixV4;
extern uint16_t g_ECSSourcePrefixV6;
extern bool g_ECSOverride;
void controlThread(int fd, ComboAddress local);
vector<std::function<void(void)>> setupLua(bool client, const std::string& config);
-NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool);
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
+const NumberedServerVector& getDownstreamCandidates(const pools_t& pools, const std::string& poolName);
std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq);
bool putMsgLen32(int fd, uint32_t len);
void* tcpAcceptorThread(void* p);
-void moreLua();
+void moreLua(bool client);
void doClient(ComboAddress server, const std::string& command);
void doConsole();
void controlClientThread(int fd, ComboAddress client);
dns.cc dns.hh \
dnscrypt.cc dnscrypt.hh \
dnsdist.cc dnsdist.hh \
+ dnsdist-cache.cc dnsdist-cache.hh \
dnsdist-carbon.cc \
dnsdist-console.cc \
dnsdist-dnscrypt.cc \
dns.hh \
test-base64_cc.cc \
test-dnsdist_cc.cc \
+ test-dnsdistpacketcache_cc.cc \
test-dnscrypt_cc.cc \
dnsdist.hh \
+ dnsdist-cache.cc dnsdist-cache.hh \
dnsdist-ecs.cc dnsdist-ecs.hh \
dnscrypt.cc dnscrypt.hh \
dnslabeltext.cc \
--- /dev/null
+../dnsdist-cache.cc
\ No newline at end of file
--- /dev/null
+../dnsdist-cache.hh
\ No newline at end of file
--- /dev/null
+../test-dnsdistpacketcache_cc.cc
\ No newline at end of file
{
public:
explicit DNSPacketMangler(std::string& packet)
- : d_packet(packet), d_notyouroffset(12), d_offset(d_notyouroffset)
+ : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
+ {}
+ DNSPacketMangler(char* packet, size_t length)
+ : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
{}
void skipLabel()
}
uint16_t get16BitInt()
{
- const char* p = d_packet.c_str() + d_offset;
+ const char* p = d_packet + d_offset;
moveOffset(2);
uint16_t ret;
memcpy(&ret, (void*)p, 2);
uint8_t get8BitInt()
{
- const char* p = d_packet.c_str() + d_offset;
+ const char* p = d_packet + d_offset;
moveOffset(1);
return *p;
}
}
void decreaseAndSkip32BitInt(uint32_t decrease)
{
- const char *p = (const char*)d_packet.c_str() + d_offset;
+ const char *p = d_packet + d_offset;
moveOffset(4);
uint32_t tmp;
memcpy(&tmp, (void*) p, sizeof(tmp));
tmp = ntohl(tmp);
tmp-=decrease;
tmp = htonl(tmp);
- d_packet.replace(d_offset-4, sizeof(tmp), (const char*)&tmp, sizeof(tmp));
+ memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
}
private:
void moveOffset(uint16_t by)
{
d_notyouroffset += by;
- if(d_notyouroffset > d_packet.length())
+ if(d_notyouroffset > d_length)
throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
- + std::to_string(d_packet.length()) );
+ + std::to_string(d_length) );
}
- std::string& d_packet;
+ char* d_packet;
+ size_t d_length;
uint32_t d_notyouroffset; // only 'moveOffset' can touch this
const uint32_t& d_offset; // look.. but don't touch
};
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
-void ageDNSPacket(std::string& packet, uint32_t seconds)
+void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
{
- if(packet.length() < sizeof(dnsheader))
+ if(length < sizeof(dnsheader))
return;
try
{
dnsheader dh;
- memcpy((void*)&dh, (const dnsheader*)packet.c_str(), sizeof(dh));
+ memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
int numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
- DNSPacketMangler dpm(packet);
+ DNSPacketMangler dpm(packet, length);
int n;
for(n=0; n < ntohs(dh.qdcount) ; ++n) {
return;
}
}
+
+void ageDNSPacket(std::string& packet, uint32_t seconds)
+{
+ ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
+}
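The pointer-and-length form of `decreaseAndSkip32BitInt` amounts to a bounds-checked read-modify-write of a big-endian 32-bit field inside a raw buffer. The block below is a minimal, self-contained sketch of that idea, not code from this patch: `decreaseTTLAt` is a hypothetical helper, and clamping at zero is an assumption made for illustration (the patched code subtracts without clamping).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper (not part of the patch): decrease the big-endian
// 32-bit value stored at buf[offset..offset+3] by 'decrease', clamping at
// zero, and return the new value.
// Throws std::out_of_range if the field does not fit inside the buffer.
inline uint32_t decreaseTTLAt(unsigned char* buf, std::size_t length,
                              std::size_t offset, uint32_t decrease)
{
  assert(buf != nullptr);
  if (length < 4 || offset > length - 4) {
    throw std::out_of_range("32-bit field does not fit in packet");
  }
  unsigned char* p = buf + offset;
  // Assemble the value byte by byte so that host endianness does not matter.
  uint32_t ttl = (static_cast<uint32_t>(p[0]) << 24) |
                 (static_cast<uint32_t>(p[1]) << 16) |
                 (static_cast<uint32_t>(p[2]) << 8) |
                 static_cast<uint32_t>(p[3]);
  ttl = (ttl > decrease) ? ttl - decrease : 0;
  // Write the result back in network byte order.
  p[0] = static_cast<unsigned char>((ttl >> 24) & 0xff);
  p[1] = static_cast<unsigned char>((ttl >> 16) & 0xff);
  p[2] = static_cast<unsigned char>((ttl >> 8) & 0xff);
  p[3] = static_cast<unsigned char>(ttl & 0xff);
  return ttl;
}
```

Working on a `char*` plus an explicit length, rather than on a `std::string&`, is what lets the same routine age a response sitting in a stack buffer or in a cache entry without copying it into a string first.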
string simpleCompress(const string& label, const string& root="");
void simpleExpandTo(const string& label, unsigned int frompos, string& ret);
+void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
void ageDNSPacket(std::string& packet, uint32_t seconds);
template<typename T>
return "set cd=1";
}
};
+
+class SkipCacheAction : public DNSAction
+{
+public:
+ DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+ {
+ dq->skipCache = true;
+ return Action::None;
+ }
+ string toString() const override
+ {
+ return "skip cache";
+ }
+};
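`SkipCacheAction` only sets a flag on the question; the README describes the effect as "don't lookup the cache for this query, don't store the answer". The sketch below models that behaviour with toy types to show why the backend ends up seeing every flagged query. `ToyQuestion`, `ToyCache` and `resolve` are hypothetical names invented for this illustration and do not reflect how `DNSDistPacketCache` is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>

// Toy stand-ins for illustration only; none of these types exist in dnsdist.
struct ToyQuestion
{
  std::string qname;
  uint16_t qtype{0};
  uint16_t qclass{0};
  bool skipCache{false}; // the flag SkipCacheAction sets on the real DNSQuestion
};

class ToyCache
{
public:
  // Returns true and fills 'out' on a hit; a flagged question never hits.
  bool get(const ToyQuestion& q, std::string& out) const
  {
    if (q.skipCache) {
      return false;
    }
    auto it = d_entries.find(makeKey(q));
    if (it == d_entries.end()) {
      return false;
    }
    out = it->second;
    return true;
  }

  // Stores the response unless the question asked to bypass the cache.
  void insert(const ToyQuestion& q, const std::string& response)
  {
    if (!q.skipCache) {
      d_entries[makeKey(q)] = response;
    }
  }

  std::size_t size() const { return d_entries.size(); }

private:
  using Key = std::tuple<std::string, uint16_t, uint16_t>;
  static Key makeKey(const ToyQuestion& q) { return Key(q.qname, q.qtype, q.qclass); }
  std::map<Key, std::string> d_entries;
};

// Answers from the cache, or from a pretend backend on a miss.
// 'backendQueries' counts how often the backend had to be asked.
inline std::string resolve(ToyCache& cache, const ToyQuestion& q, unsigned int& backendQueries)
{
  assert(!q.qname.empty()); // the toy model expects a non-empty name
  std::string response;
  if (cache.get(q, response)) {
    return response;
  }
  ++backendQueries;
  response = "answer for " + q.qname;
  cache.insert(q, response);
  return response;
}
```

In this model, ten identical unflagged questions reach the pretend backend once, while ten flagged ones reach it ten times and leave the cache unchanged. That is the same accounting the `testCached` and `testSkipCache` regression tests check through `_responsesCounter`.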
--- /dev/null
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "iputils.hh"
+#include "dnsdist-cache.hh"
+#include "dnswriter.hh"
+
+BOOST_AUTO_TEST_SUITE(dnsdistpacketcache_cc)
+
+BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
+ const size_t maxEntries = 150000;
+ DNSDistPacketCache PC(maxEntries, 86400, 1);
+ BOOST_CHECK_EQUAL(PC.getSize(), 0);
+
+ size_t counter=0;
+ size_t skipped=0;
+ try {
+ for(counter = 0; counter < 100000; ++counter) {
+ DNSName a=DNSName("hello ")+DNSName(std::to_string(counter));
+ BOOST_CHECK_EQUAL(DNSName(a.toString()), a);
+
+ vector<uint8_t> query;
+ DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+
+ vector<uint8_t> response;
+ DNSPacketWriter pwR(response, a, QType::A, QClass::IN, 0);
+ pwR.getHeader()->rd = 1;
+ pwR.getHeader()->ra = 1;
+ pwR.getHeader()->qr = 1;
+ pwR.getHeader()->id = pwQ.getHeader()->id;
+ pwR.startRecord(a, QType::A, 100, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfr32BitInt(0x01020304);
+ pwR.commit();
+ uint16_t responseLen = response.size();
+
+ char responseBuf[4096];
+ uint16_t responseBufSize = sizeof(responseBuf);
+ uint32_t key = 0;
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ BOOST_CHECK_EQUAL(found, false);
+
+ PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+
+ found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, true);
+ if (found == true) {
+ BOOST_CHECK_EQUAL(responseBufSize, responseLen);
+ int match = memcmp(responseBuf, response.data(), responseLen);
+ BOOST_CHECK_EQUAL(match, 0);
+ }
+ else {
+ skipped++;
+ }
+ }
+
+ BOOST_CHECK_EQUAL(skipped, PC.getInsertCollisions());
+ BOOST_CHECK_EQUAL(PC.getSize(), counter - skipped);
+
+ size_t deleted=0;
+ size_t delcounter=0;
+ for(delcounter=0; delcounter < counter/1000; ++delcounter) {
+ DNSName a=DNSName("hello ")+DNSName(std::to_string(delcounter));
+ vector<uint8_t> query;
+ DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+ char responseBuf[4096];
+ uint16_t responseBufSize = sizeof(responseBuf);
+ uint32_t key = 0;
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ if (found == true) {
+ PC.expunge(a);
+ deleted++;
+ }
+ }
+ BOOST_CHECK_EQUAL(PC.getSize(), counter - skipped - deleted);
+
+ size_t matches=0;
+ vector<DNSResourceRecord> entry;
+ size_t expected=counter-skipped-deleted;
+ for(; delcounter < counter; ++delcounter) {
+ DNSName a(DNSName("hello ")+DNSName(std::to_string(delcounter)));
+ vector<uint8_t> query;
+ DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+ uint16_t len = query.size();
+ uint32_t key = 0;
+ char response[4096];
+ uint16_t responseSize = sizeof(response);
+ if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key)) {
+ matches++;
+ }
+ }
+ BOOST_CHECK_EQUAL(matches, expected);
+ }
+ catch(PDNSException& e) {
+ cerr<<"Had error: "<<e.reason<<endl;
+ throw;
+ }
+}
+
+static DNSDistPacketCache PC(500000);
+
+static void *threadMangler(void* a)
+{
+ try {
+ unsigned int offset=(unsigned int)(unsigned long)a;
+ for(unsigned int counter=0; counter < 100000; ++counter) {
+ DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset));
+ vector<uint8_t> query;
+ DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+
+ vector<uint8_t> response;
+ DNSPacketWriter pwR(response, a, QType::A, QClass::IN, 0);
+ pwR.getHeader()->rd = 1;
+ pwR.getHeader()->ra = 1;
+ pwR.getHeader()->qr = 1;
+ pwR.getHeader()->id = pwQ.getHeader()->id;
+ pwR.startRecord(a, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfr32BitInt(0x01020304);
+ pwR.commit();
+ uint16_t responseLen = response.size();
+
+ char responseBuf[4096];
+ uint16_t responseBufSize = sizeof(responseBuf);
+ uint32_t key = 0;
+ PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+
+ PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+ }
+ }
+ catch(PDNSException& e) {
+ cerr<<"Had error: "<<e.reason<<endl;
+ throw;
+ }
+ return 0;
+}
+
+AtomicCounter g_missing;
+
+static void *threadReader(void* a)
+{
+ try
+ {
+ unsigned int offset=(unsigned int)(unsigned long)a;
+ vector<DNSResourceRecord> entry;
+ for(unsigned int counter=0; counter < 100000; ++counter) {
+ DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset));
+ vector<uint8_t> query;
+ DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+
+ char responseBuf[4096];
+ uint16_t responseBufSize = sizeof(responseBuf);
+ uint32_t key = 0;
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ if (!found) {
+ g_missing++;
+ }
+ }
+ }
+ catch(PDNSException& e) {
+ cerr<<"Had error in threadReader: "<<e.reason<<endl;
+ throw;
+ }
+ return 0;
+}
+
+BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
+ try {
+ pthread_t tid[4];
+ for(int i=0; i < 4; ++i)
+ pthread_create(&tid[i], 0, threadMangler, (void*)(i*1000000UL));
+ void* res;
+ for(int i=0; i < 4 ; ++i)
+ pthread_join(tid[i], &res);
+
+ BOOST_CHECK_EQUAL(PC.getSize() + PC.getDeferredInserts() + PC.getInsertCollisions(), 400000);
+ BOOST_CHECK_SMALL(1.0*PC.getInsertCollisions(), 10000.0);
+
+ for(int i=0; i < 4; ++i)
+ pthread_create(&tid[i], 0, threadReader, (void*)(i*1000000UL));
+ for(int i=0; i < 4 ; ++i)
+ pthread_join(tid[i], &res);
+
+ BOOST_CHECK((PC.getDeferredInserts() + PC.getDeferredLookups() + PC.getInsertCollisions()) >= g_missing);
+ }
+ catch(PDNSException& e) {
+ cerr<<"Had error: "<<e.reason<<endl;
+ throw;
+ }
+
+}
+
+BOOST_AUTO_TEST_SUITE_END()
(_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
receivedResponse.id = expectedResponse.id
self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedCaching(DNSDistTest):
+
+ _config_template = """
+ pc = newPacketCache(5, 86400, 1)
+ getPool(""):setCache(pc)
+ addAction(makeRule("nocache.tests.powerdns.com."), SkipCacheAction())
+ newServer{address="127.0.0.1:%s"}
+ """
+ def testCached(self):
+ """
+ Advanced: Served from cache
+
+ dnsdist is configured to cache entries. We send several
+ identical requests and check that the backend only receives
+ the first one.
+ """
+ numberOfQueries = 10
+ name = 'cached.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+
+ # first query to fill the cache
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ for idx in range(numberOfQueries):
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ total = 0
+ for key in TestAdvancedCaching._responsesCounter:
+ total += TestAdvancedCaching._responsesCounter[key]
+
+ self.assertEquals(total, 1)
+
+ def testSkipCache(self):
+ """
+ Advanced: SkipCacheAction
+
+ dnsdist is configured to not cache entries for nocache.tests.powerdns.com.
+ We send several requests and check that the backend gets them all.
+ """
+ name = 'nocache.tests.powerdns.com.'
+ numberOfQueries = 10
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+
+ for idx in range(numberOfQueries):
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ for key in TestAdvancedCaching._responsesCounter:
+ value = TestAdvancedCaching._responsesCounter[key]
+ self.assertEquals(value, numberOfQueries)
+
+ def testCacheExpiration(self):
+ """
+ Advanced: Cache expiration
+
+ dnsdist is configured to cache entries. We send one request
+ (cache miss) with a very short TTL and check that the next requests
+ are served from the cache. Then we wait for the TTL to expire and check
+ that the next request is a miss but the following one is a hit.
+ """
+ ttl = 2
+ misses = 0
+ name = 'cacheexpiration.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ ttl,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+
+ # first query to fill the cache
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+ misses += 1
+
+ # next queries should hit the cache
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ # now we wait a bit for the cache entry to expire
+ time.sleep(ttl + 1)
+
+ # next query should be a miss, fill the cache again
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+ misses += 1
+
+ # following queries should hit the cache again
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+
+ total = 0
+ for key in TestAdvancedCaching._responsesCounter:
+ total += TestAdvancedCaching._responsesCounter[key]
+
+ self.assertEquals(total, misses)
+
+ def testCacheDecreaseTTL(self):
+ """
+ Advanced: Cache decreases TTL
+
+ dnsdist is configured to cache entries. We send one request
+ (cache miss) and verify that the cache hits have a decreasing TTL.
+ """
+ ttl = 600
+ misses = 0
+ name = 'cachedecreasettl.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ ttl,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+
+ # first query to fill the cache
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+ misses += 1
+
+ # next queries should hit the cache
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+ for an in receivedResponse.answer:
+ self.assertTrue(an.ttl <= ttl)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+ for an in receivedResponse.answer:
+ self.assertTrue(an.ttl <= ttl)
+
+ # now we wait a bit for the TTL to decrease
+ time.sleep(1)
+
+ # next queries should hit the cache
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+ for an in receivedResponse.answer:
+ self.assertTrue(an.ttl < ttl)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ receivedResponse.id = response.id
+ self.assertEquals(receivedResponse, response)
+ for an in receivedResponse.answer:
+ self.assertTrue(an.ttl < ttl)
+
+ total = 0
+ for key in TestAdvancedCaching._responsesCounter:
+ total += TestAdvancedCaching._responsesCounter[key]
+
+ self.assertEquals(total, misses)
+
+ def testCacheDifferentCase(self):
+ """
+ Advanced: Cache matches different case
+
+ dnsdist is configured to cache entries. We send one request
+ (cache miss) and verify that the same query in a different case
+ still hits the cache.
+ """
+ ttl = 600
+ name = 'cachedifferentcase.tests.powerdns.com.'
+ differentCaseName = 'CacheDifferentCASE.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ differentCaseQuery = dns.message.make_query(differentCaseName, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ differentCaseResponse = dns.message.make_response(differentCaseQuery)
+ rrset = dns.rrset.from_text(name,
+ ttl,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+ differentCaseResponse.answer.append(rrset)
+
+ # first query to fill the cache
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ # different case query should still hit the cache
+ (_, receivedResponse) = self.sendUDPQuery(differentCaseQuery, response=None, useQueue=False)
+ receivedResponse.id = differentCaseResponse.id
+ self.assertEquals(receivedResponse, differentCaseResponse)
+
+ (_, receivedResponse) = self.sendTCPQuery(differentCaseQuery, response=None, useQueue=False)
+ receivedResponse.id = differentCaseResponse.id
+ self.assertEquals(receivedResponse, differentCaseResponse)
+
+class TestAdvancedCachingWithExistingEDNS(DNSDistTest):
+
+ _config_template = """
+ pc = newPacketCache(5, 86400, 1)
+ getPool(""):setCache(pc)
+ newServer{address="127.0.0.1:%s"}
+ """
+ def testCacheWithEDNS(self):
+ """
+ Advanced: Cache should not match different EDNS value
+
+ dnsdist is configured to cache entries. We send one request
+ (cache miss) and verify that the same query with a different EDNS UDP
+ payload size is not served from the cache.
+ """
+ misses = 0
+ name = 'cachedifferentedns.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=512)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+ misses += 1
+
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+ misses += 1
+
+ total = 0
+ for key in TestAdvancedCachingWithExistingEDNS._responsesCounter:
+ total += TestAdvancedCachingWithExistingEDNS._responsesCounter[key]
+
+ self.assertEquals(total, misses)