]> granicus.if.org Git - pdns/commitdiff
dnsdist: Don't parse DNS names when caching responses
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 1 Mar 2016 15:47:04 +0000 (16:47 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 1 Mar 2016 15:47:04 +0000 (16:47 +0100)
Fix a crash reported by @rygl:

terminate called after throwing an instance of 'std::out_of_range'
what():  dnsname issue: Found a forward reference during label decompression

pdns/dnsdist-cache.cc
pdns/dnsparser.cc
pdns/dnsparser.hh

index 29a1c587e58506d166a6f140ed91d858b4755fa4..2d951bc56b61e6f522ea68ef4002a9d94ed91cd6 100644 (file)
@@ -190,51 +190,7 @@ bool DNSDistPacketCache::isFull()
 
 uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length)
 {
-  const struct dnsheader* dh = (const struct dnsheader*) packet;
-  uint32_t result = std::numeric_limits<uint32_t>::max();
-  vector<uint8_t> content(length - sizeof(dnsheader));
-  copy(packet + sizeof(dnsheader), packet + length, content.begin());
-  PacketReader pr(content);
-  size_t idx = 0;
-  DNSName rrname;
-  uint16_t qdcount = ntohs(dh->qdcount);
-  uint16_t ancount = ntohs(dh->ancount);
-  uint16_t nscount = ntohs(dh->nscount);
-  uint16_t arcount = ntohs(dh->arcount);
-  uint16_t rrtype;
-  uint16_t rrclass;
-  struct dnsrecordheader ah;
-
-  /* consume qd */
-  for(idx = 0; idx < qdcount; idx++) {
-    rrname = pr.getName();
-    rrtype = pr.get16BitInt();
-    rrclass = pr.get16BitInt();
-    (void) rrtype;
-    (void) rrclass;
-  }
-
-  /* consume AN and NS */
-  for (idx = 0; idx < ancount + nscount; idx++) {
-    rrname = pr.getName();
-    pr.getDnsrecordheader(ah);
-    pr.d_pos += ah.d_clen;
-    if (result > ah.d_ttl)
-      result = ah.d_ttl;
-  }
-
-  /* consume AR, watch for OPT */
-  for (idx = 0; idx < arcount; idx++) {
-    rrname = pr.getName();
-    pr.getDnsrecordheader(ah);
-    pr.d_pos += ah.d_clen;
-    if (ah.d_type == QType::OPT) {
-      continue;
-    }
-    if (result > ah.d_ttl)
-      result = ah.d_ttl;
-  }
-  return result;
+  return getDNSPacketMinTTL(packet, length);
 }
 
 uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
index 67f64a32e06ecc36bfced26ecd7ee053c29075ef..f5330aadb6479ca53b4be289cdd18b128a2fe9aa 100644 (file)
@@ -586,6 +586,14 @@ public:
   {
       moveOffset(bytes);
   }
+  uint32_t get32BitInt()
+  {
+    const char* p = d_packet + d_offset;
+    moveOffset(4);
+    uint32_t ret;
+    memcpy(&ret, (void*)p, 4);
+    return ntohl(ret);
+  }
   uint16_t get16BitInt()
   {
     const char* p = d_packet + d_offset;
@@ -676,3 +684,42 @@ void ageDNSPacket(std::string& packet, uint32_t seconds)
 {
   ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
 }
+
+uint32_t getDNSPacketMinTTL(const char* packet, size_t length)
+{
+  uint32_t result = std::numeric_limits<uint32_t>::max();
+  if(length < sizeof(dnsheader)) {
+    return result;
+  }
+  try
+  {
+    const dnsheader* dh = (const dnsheader*) packet;
+    DNSPacketMangler dpm(const_cast<char*>(packet), length);
+
+    const uint16_t qdcount = ntohs(dh->qdcount);
+    for(size_t n = 0; n < qdcount; ++n) {
+      dpm.skipLabel();
+      dpm.skipBytes(4); // qtype, qclass
+    }
+    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+    for(size_t n = 0; n < numrecords; ++n) {
+      dpm.skipLabel();
+
+      const uint16_t dnstype = dpm.get16BitInt();
+      /* uint16_t dnsclass = */ dpm.get16BitInt();
+
+      if(dnstype == QType::OPT)
+        break;
+
+      const uint32_t ttl = dpm.get32BitInt();
+      if (result > ttl)
+        result = ttl;
+
+      dpm.skipRData();
+    }
+  }
+  catch(...)
+  {
+  }
+  return result;
+}
index 94df048ec7b8797b984fe99a0f9f5864a08da1d0..35767c3fa73b9af86e040d2b62f70ea03cf264ef 100644 (file)
@@ -346,6 +346,7 @@ string simpleCompress(const string& label, const string& root="");
 void simpleExpandTo(const string& label, unsigned int frompos, string& ret);
 void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
 void ageDNSPacket(std::string& packet, uint32_t seconds);
+uint32_t getDNSPacketMinTTL(const char* packet, size_t length);
 
 template<typename T>
 std::shared_ptr<T> getRR(const DNSRecord& dr)