]> granicus.if.org Git - pdns/commitdiff
dnsdist: Prevent the cache ptr from being altered under our feet
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 25 Feb 2016 14:46:22 +0000 (15:46 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 25 Feb 2016 14:46:22 +0000 (15:46 +0100)
Make sure we hold the Lua mutex before getting the packet cache
shared_ptr, so that we don't have a thread reading it at the
exact same time it is altered by another.
We could have used atomic_load/atomic_store but libstdc++ uses
a pool of mutex for that anyway.
This might fix #3396.

pdns/dnsdist-tcp.cc
pdns/dnsdist.cc

index e2014152f5337c1c750289f4374af7ca5fafb0ea..5b18cedc0780ba12fddfc0dc14b86f890664d35c 100644 (file)
@@ -323,9 +323,11 @@ void* tcpClientThread(int pipefd)
          break;
 
         std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
+        std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
          ds = localPolicy->policy(serverPool->servers, &dq);
+         packetCache = serverPool->packetCache;
        }
        if(!ds) {
          g_stats.noPolicy++;
@@ -345,10 +347,10 @@ void* tcpClientThread(int pipefd)
         }
 
         uint32_t cacheKey = 0;
-        if (serverPool->packetCache && !dq.skipCache) {
+        if (packetCache && !dq.skipCache) {
           char cachedResponse[4096];
           uint16_t cachedResponseSize = sizeof cachedResponse;
-          if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, true, &cacheKey)) {
+          if (packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, true, &cacheKey)) {
             if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
               writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
             g_stats.cacheHits++;
@@ -475,8 +477,8 @@ void* tcpClientThread(int pipefd)
          memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
        }
 
-       if (serverPool->packetCache && !dq.skipCache) {
-         serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true);
+       if (packetCache && !dq.skipCache) {
+         packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true);
        }
 
 #ifdef HAVE_DNSCRYPT
index 9bd5d2de1fdc076000bfbea89c4c0a2853bbed15..d83f6b5a5916c2b38cb8b006284b919239f544e0 100644 (file)
@@ -805,10 +805,12 @@ try
 
       DownstreamState* ss = nullptr;
       std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
+      std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
       auto policy=localPolicy->policy;
       {
        std::lock_guard<std::mutex> lock(g_luamutex);
        ss = policy(serverPool->servers, &dq).get();
+       packetCache = serverPool->packetCache;
       }
 
       if(!ss) {
@@ -822,10 +824,10 @@ try
       }
 
       uint32_t cacheKey = 0;
-      if (serverPool->packetCache && !dq.skipCache) {
+      if (packetCache && !dq.skipCache) {
         char cachedResponse[4096];
         uint16_t cachedResponseSize = sizeof cachedResponse;
-        if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, false, &cacheKey)) {
+        if (packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, false, &cacheKey)) {
           ComboAddress dest;
           if(HarvestDestinationAddress(&msgh, &dest))
             sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
@@ -863,7 +865,7 @@ try
       ids->ednsAdded = false;
       ids->cacheKey = cacheKey;
       ids->skipCache = dq.skipCache;
-      ids->packetCache = serverPool->packetCache;
+      ids->packetCache = packetCache;
       ids->ednsAdded = ednsAdded;
 #ifdef HAVE_DNSCRYPT
       ids->dnsCryptQuery = dnsCryptQuery;
@@ -1017,9 +1019,14 @@ void* maintThread()
     counter++;
     if (counter >= g_cacheCleaningDelay) {
       const auto localPools = g_pools.getCopy();
+      std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
       for (const auto& entry : localPools) {
-        if (entry.second->packetCache) {
-          entry.second->packetCache->purge();
+        {
+          std::lock_guard<std::mutex> lock(g_luamutex);
+          packetCache = entry.second->packetCache;
+        }
+        if (packetCache) {
+          packetCache->purge();
         }
       }
       counter = 0;