]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add DNSQuestion:getEDNSOptions() to access incoming EDNS options
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Dec 2017 15:36:58 +0000 (16:36 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 2 Oct 2018 08:59:03 +0000 (10:59 +0200)
14 files changed:
pdns/dnsdist-cache.cc
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua-vars.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/docs/reference/dq.rst
pdns/test-dnsdist_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_EDNSOptions.py [new file with mode: 0644]

index bf092f7309f10b9d38d2b7f2738d5df9bf58560a..f9162e16b729fb71b117fb89fa1ecc0d7aef5c2a 100644 (file)
@@ -54,16 +54,16 @@ DNSDistPacketCache::~DNSDistPacketCache()
 
 bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional<Netmask>& subnet)
 {
-  char * optRDLen = NULL;
+  uint16_t optRDPosition;
   size_t remaining = 0;
 
-  int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDLen, &remaining);
+  int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDPosition, &remaining);
 
   if (res == 0) {
-    char * ecsOptionStart = NULL;
+    char * ecsOptionStart = nullptr;
     size_t ecsOptionSize = 0;
 
-    res = getEDNSOption(optRDLen, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
+    res = getEDNSOption(const_cast<char*>(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
 
     if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
 
index b94b0eeeed1de7c99ae81937f1453f2b9e97b55d..71b852f25f4274b42d32147ad91494197dd3af9e 100644 (file)
@@ -66,7 +66,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
   rrname = pr.getName();
   rrtype = pr.get16BitInt();
   rrclass = pr.get16BitInt();
-  
+
   DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
   pw.getHeader()->id=dh->id;
   pw.getHeader()->qr=dh->qr;
@@ -77,7 +77,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
   pw.getHeader()->ad=dh->ad;
   pw.getHeader()->cd=dh->cd;
   pw.getHeader()->rcode=dh->rcode;
-  
+
   /* consume remaining qd if any */
   if (qdcount > 1) {
     for(idx = 1; idx < qdcount; idx++) {
@@ -191,13 +191,13 @@ int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * opt
 }
 
 /* extract the start of the OPT RR in a QUERY packet if any */
-int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, char ** optRDLen, size_t * remaining)
+int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t len, uint16_t* optRDPosition, size_t * remaining)
 {
-  assert(packet != NULL);
-  assert(optRDLen != NULL);
-  assert(remaining != NULL);
+  assert(packet != nullptr);
+  assert(optRDPosition != nullptr);
+  assert(remaining != nullptr);
   const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
-  
+
   if (offset >= len) {
     return ENOENT;
   }
@@ -229,13 +229,13 @@ int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, cha
     return ENOENT;
 
   pos += DNS_TTL_SIZE;
-  *optRDLen = packet + pos;
+  *optRDPosition = pos;
   *remaining = len - pos;
 
   return 0;
 }
 
-static void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
+void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
 {
   Netmask sourceNetmask(source, ECSPrefixLength);
   EDNSSubnetOpts ecsOpts;
@@ -264,22 +264,20 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   res.append(optRData.c_str(), optRData.length());
 }
 
-static bool replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, uint16_t ECSPrefixLength)
+static bool replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, const string& newECSOption)
 {
   assert(packet != NULL);
   assert(len != NULL);
   assert(oldEcsOptionStart != NULL);
   assert(optRDLen != NULL);
-  string ECSOption;
-  generateECSOption(remote, ECSOption, ECSPrefixLength);
 
-  if (ECSOption.size() == oldEcsOptionSize) {
+  if (newECSOption.size() == oldEcsOptionSize) {
     /* same size as the existing option */
-    memcpy(oldEcsOptionStart, ECSOption.c_str(), oldEcsOptionSize);
+    memcpy(oldEcsOptionStart, newECSOption.c_str(), oldEcsOptionSize);
   }
   else {
     /* different size than the existing option */
-    const unsigned int newPacketLen = *len + (ECSOption.length() - oldEcsOptionSize);
+    const unsigned int newPacketLen = *len + (newECSOption.length() - oldEcsOptionSize);
     const size_t beforeOptionLen = oldEcsOptionStart - packet;
     const size_t dataBehindSize = *len - beforeOptionLen - oldEcsOptionSize;
 
@@ -290,91 +288,140 @@ static bool replaceEDNSClientSubnetOption(char * const packet, const size_t pack
 
     /* fix the size of ECS Option RDLen */
     uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1];
-    newRDLen += (ECSOption.size() - oldEcsOptionSize);
+    newRDLen += (newECSOption.size() - oldEcsOptionSize);
     optRDLen[0] = newRDLen / 256;
     optRDLen[1] = newRDLen % 256;
 
     if (dataBehindSize > 0) {
       memmove(oldEcsOptionStart, oldEcsOptionStart + oldEcsOptionSize, dataBehindSize);
     }
-    memcpy(oldEcsOptionStart + dataBehindSize, ECSOption.c_str(), ECSOption.size());
+    memcpy(oldEcsOptionStart + dataBehindSize, newECSOption.c_str(), newECSOption.size());
     *len = newPacketLen;
   }
 
   return true;
 }
 
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength)
+/* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
+   and false otherwise. */
+bool parseEDNSOptions(DNSQuestion& dq)
 {
-  assert(packet != NULL);
-  assert(len != NULL);
-  assert(consumed <= (size_t) *len);
-  assert(ednsAdded != NULL);
-  assert(ecsAdded != NULL);
-  unsigned char * optRDLen = NULL;
-  size_t remaining = 0;
+  assert(dq.dh != nullptr);
+  assert(dq.consumed <= dq.len);
+  assert(dq.len <= dq.size);
+
+  if (dq.ednsOptions != nullptr) {
+    return true;
+  }
 
-  int res = getEDNSOptionsStart(packet, consumed, *len, (char**) &optRDLen, &remaining);
+  dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
+  const char* packet = reinterpret_cast<const char*>(dq.dh);
+
+  size_t remaining = 0;
+  uint16_t optRDPosition;
+  int res = getEDNSOptionsStart(packet, dq.consumed, dq.len, &optRDPosition, &remaining);
 
   if (res == 0) {
-    char * ecsOptionStart = NULL;
-    size_t ecsOptionSize = 0;
-    
-    res = getEDNSOption((char*)optRDLen, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
-    
-    if (res == 0) {
-      /* there is already an ECS value */
-      if (overrideExisting) {
-        return replaceEDNSClientSubnetOption(packet, packetSize, len, remote, ecsOptionStart, ecsOptionSize, optRDLen, ecsPrefixLength);
-      }
-    } else {
-      /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
-      /* getEDNSOptionsStart has already checked that there is exactly one AR,
-         no NS and no AN */
-      string ECSOption;
-      generateECSOption(remote, ECSOption, ecsPrefixLength);
-      const size_t ECSOptionSize = ECSOption.size();
-      
-      /* check if the existing buffer is large enough */
-      if (packetSize - *len <= ECSOptionSize) {
-        return false;
-      }
+    res = getEDNSOptions(packet + optRDPosition, remaining, *dq.ednsOptions);
+    return (res == 0);
+  }
 
-      uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1];
-      newRDLen += ECSOptionSize;
-      optRDLen[0] = newRDLen / 256;
-      optRDLen[1] = newRDLen % 256;
+  return false;
+}
 
-      memcpy(packet + *len, ECSOption.c_str(), ECSOptionSize);
-      *len += ECSOptionSize;
-      *ecsAdded = true;
-    }
+static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded)
+{
+  /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
+  /* getEDNSOptionsStart has already checked that there is exactly one AR,
+     no NS and no AN */
+
+  /* check if the existing buffer is large enough */
+  const size_t newECSOptionSize = newECSOption.size();
+  if (packetSize - *len <= newECSOptionSize) {
+    return false;
   }
-  else {
-    /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
-    string EDNSRR;
-    struct dnsheader* dh = (struct dnsheader*) packet;
-    string optRData;
-    generateECSOption(remote, optRData, ecsPrefixLength);
-    generateOptRR(optRData, EDNSRR, g_EdnsUDPPayloadSize, false);
-
-    /* does it fit in the existing buffer? */
-    if (packetSize - *len <= EDNSRR.size()) {
-      return false;
-    }
 
-    uint16_t arcount = ntohs(dh->arcount);
-    arcount++;
-    dh->arcount = htons(arcount);
-    *ednsAdded = true;
+  uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1];
+  newRDLen += newECSOptionSize;
+  optRDLen[0] = newRDLen / 256;
+  optRDLen[1] = newRDLen % 256;
+
+  memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
+  *len += newECSOptionSize;
+  *ecsAdded = true;
+
+  return true;
+}
+
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded)
+{
+  /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
+  string EDNSRR;
+  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
+  generateOptRR(newECSOption, EDNSRR, g_EdnsUDPPayloadSize, false);
+
+  /* does it fit in the existing buffer? */
+  if (packetSize - *len <= EDNSRR.size()) {
+    return false;
+  }
+
+  uint16_t arcount = ntohs(dh->arcount);
+  arcount++;
+  dh->arcount = htons(arcount);
+  *ednsAdded = true;
+
+  memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size());
+  *len += EDNSRR.size();
+  return true;
+}
+
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+  assert(packet != nullptr);
+  assert(len != nullptr);
+  assert(consumed <= (size_t) *len);
+  assert(ednsAdded != nullptr);
+  assert(ecsAdded != nullptr);
+  uint16_t optRDPosition = 0;
+  size_t remaining = 0;
+
+  int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
 
-    memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size());
-    *len += EDNSRR.size();
+  if (res != 0) {
+    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded);
+  }
+
+  unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
+  char * ecsOptionStart = nullptr;
+  size_t ecsOptionSize = 0;
+
+  res = getEDNSOption(reinterpret_cast<char*>(optRDLen), remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
+
+  if (res == 0) {
+    /* there is already an ECS value */
+    if (!overrideExisting) {
+      return true;
+    }
+
+    return replaceEDNSClientSubnetOption(packet, packetSize, len, ecsOptionStart, ecsOptionSize, optRDLen, newECSOption);
+  } else {
+    /* we have an EDNS OPT RR but no existing ECS option */
+    return addECSToExistingOPT(packet, packetSize, len, newECSOption, optRDLen, ecsAdded);
   }
 
   return true;
 }
 
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded)
+{
+  assert(dq.remote != nullptr);
+  string newECSOption;
+  generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() :  dq.ecsPrefixLength);
+  char* packet = reinterpret_cast<char*>(dq.dh);
+
+  return handleEDNSClientSubnet(packet, dq.size, dq.consumed, &dq.len, ednsAdded, ecsAdded, dq.ecsOverride, newECSOption);
+}
+
 static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
 {
   unsigned char* p = optionsStart;
@@ -582,11 +629,11 @@ bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uin
 
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
 {
-  char* optRDLen = nullptr;
+  uint16_t optRDPosition;
   /* remaining is at least the size of the rdlen + the options if any + the following records if any */
   size_t remaining = 0;
 
-  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDLen, &remaining);
+  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDPosition, &remaining);
 
   if (res != 0) {
     /* if the initial query did not have EDNS0, we are done */
@@ -599,6 +646,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
     return false;
   }
 
+  char* optRDLen = reinterpret_cast<char*>(dq.dh) + optRDPosition;
   char * optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2));
 
   const uint8_t* zPtr = (const uint8_t*) optPtr + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
@@ -663,10 +711,10 @@ catch(...)
 
 bool queryHasEDNS(const DNSQuestion& dq)
 {
-  char * optRDLen = nullptr;
+  uint16_t optRDPosition;
   size_t ecsRemaining = 0;
 
-  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDLen, &ecsRemaining);
+  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDPosition, &ecsRemaining);
   if (res == 0) {
     return true;
   }
index c9193f61d65c9a62617d4dbb6c4c57894d82f790..5b664407bcdfa99e3c96c81424cb04a44f7ee519 100644 (file)
@@ -26,15 +26,20 @@ extern uint16_t g_PayloadSizeSelfGenAnswers;
 
 int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent);
 int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last);
-bool handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength);
 void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, bool dnssecOK);
+void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength);
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
 int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
-int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, char ** optRDLen, size_t * remaining);
+int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t len, uint16_t* optRDPosition, size_t * remaining);
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind);
 bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
 
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded);
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption);
+
+bool parseEDNSOptions(DNSQuestion& dq);
+
 int getEDNSZ(const DNSQuestion& dq);
 bool queryHasEDNS(const DNSQuestion& dq);
 
index 377add214c5aaee840c3380bf7853427baf500e8..69c8f069d5bbeecddaa4252475ac7023179e2049 100644 (file)
@@ -155,7 +155,6 @@ TeeAction::~TeeAction()
   d_worker.join();
 }
 
-
 DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) const
 {
   if(dq->tcp) {
@@ -173,7 +172,10 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) con
       query.reserve(dq->size);
       query.assign((char*) dq->dh, len);
 
-      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, dq->ecsOverride, dq->ecsSet ? dq->ecs.getBits() :  dq->ecsPrefixLength)) {
+      string newECSOption;
+      generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() :  dq->ecsPrefixLength);
+
+      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsOverride, newECSOption)) {
         return DNSAction::Action::None;
       }
 
index abb3a4fc2fc7885cbdd00f2469931e7e248fca31..52ccf2711c79b46fa3d22a4adee7863465c1c521 100644 (file)
@@ -55,6 +55,15 @@ void setupLuaBindingsDNSQuestion()
   g_lua.registerFunction<bool(DNSQuestion::*)()>("getDO", [](const DNSQuestion& dq) {
       return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
     });
+
+  g_lua.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()>("getEDNSOptions", [](DNSQuestion& dq) {
+      if (dq.ednsOptions == nullptr) {
+        parseEDNSOptions(dq);
+      }
+
+      return *dq.ednsOptions;
+    });
+
   g_lua.registerFunction<void(DNSQuestion::*)(std::string)>("sendTrap", [](const DNSQuestion& dq, boost::optional<std::string> reason) {
 #ifdef HAVE_NET_SNMP
       if (g_snmpAgent && g_snmpTrapsEnabled) {
index 97c3a0ef43331318bb7d4deb2cd493d4c342d907..bc2b9f1dc1483a0eb838fcff5f6e48020ecbfcaf 100644 (file)
@@ -571,4 +571,16 @@ void setupLuaBindings(bool client)
       }
     });
 #endif /* HAVE_EBPF */
+
+  /* EDNSOptionView */
+  g_lua.registerFunction<size_t(EDNSOptionView::*)()>("count", [](const EDNSOptionView& option) {
+      return option.values.size();
+    });
+  g_lua.registerFunction<std::vector<std::pair<int, string>>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
+    std::vector<std::pair<int, string> > values;
+    for (const auto& value : option.values) {
+      values.push_back(std::make_pair(values.size(), std::string(value.content, value.size)));
+    }
+    return values;
+  });
 }
index 7fe5afdd3917415be116d05d90e6b7cc6345b318..4ce42b5c7dd1d174aafdeb49ca5adea1be14ac17 100644 (file)
@@ -118,4 +118,17 @@ void setupLuaVars()
         { "VERSION2", DNSCryptExchangeVersion::VERSION2 },
     });
 #endif
+
+  g_lua.writeVariable("EDNSOptionCode", std::unordered_map<string, uint8_t>{
+      { "NSID", EDNSOptionCode::NSID },
+      { "DAU", EDNSOptionCode::DAU },
+      { "DHU", EDNSOptionCode::DHU },
+      { "N3U", EDNSOptionCode::N3U },
+      { "ECS", EDNSOptionCode::ECS },
+      { "EXPIRE", EDNSOptionCode::EXPIRE },
+      { "COOKIE", EDNSOptionCode::COOKIE },
+      { "TCPKEEPALIVE", EDNSOptionCode::TCPKEEPALIVE },
+      { "PADDING", EDNSOptionCode::PADDING },
+      { "CHAIN", EDNSOptionCode::CHAIN }
+    });
 }
index 5ca8ba071777e332f2440c8f579b2a70bc09730f..e5487fec3b62b41d62929be583a0bdda32003c91 100644 (file)
@@ -404,12 +404,10 @@ void* tcpClientThread(int pipefd)
         }
 
         if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
-          uint16_t newLen = dq.len;
-          if (!handleEDNSClientSubnet(query, dq.size, consumed, &newLen, &ednsAdded, &ecsAdded, dq.ecsSet ? dq.ecs.getNetwork() : ci.remote, dq.ecsOverride, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength)) {
+          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded))) {
             vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
             goto drop;
           }
-          dq.len = newLen;
         }
 
         uint32_t cacheKey = 0;
index 4173879f6de4f12c405d3e63b9f73c1659338c16..9433947b0b6dfa7a0ca4e587ca24b8f551efc934 100644 (file)
@@ -1403,7 +1403,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     bool ednsAdded = false;
     bool ecsAdded = false;
     if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
-      if (!handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, &(ednsAdded), &(ecsAdded), dq.ecsSet ? dq.ecs.getNetwork() : remote, dq.ecsOverride, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength)) {
+      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded))) {
         vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
         return;
       }
index c66e1f217b5a005958be8871e3f83e2679e69841..b4f800e8a03400b6f38953dce6bd07bfa795ae60 100644 (file)
 #pragma once
 #include "config.h"
 #include "ext/luawrapper/include/LuaContext.hpp"
-#include <time.h>
-#include "misc.hh"
-#include "mplexer.hh"
-#include "iputils.hh"
-#include "dnsname.hh"
+
 #include <atomic>
-#include <boost/variant.hpp>
 #include <mutex>
+#include <string>
 #include <thread>
+#include <time.h>
 #include <unistd.h>
-#include "sholder.hh"
+#include <unordered_map>
+
+#include <boost/circular_buffer.hpp>
+#include <boost/variant.hpp>
+
+#include "bpf-filter.hh"
 #include "dnscrypt.hh"
 #include "dnsdist-cache.hh"
-#include "gettime.hh"
 #include "dnsdist-dynbpf.hh"
-#include "bpf-filter.hh"
-#include <string>
-#include <unordered_map>
+#include "dnsname.hh"
+#include "ednsoptions.hh"
+#include "gettime.hh"
+#include "iputils.hh"
+#include "misc.hh"
+#include "mplexer.hh"
+#include "sholder.hh"
 #include "tcpiohandler.hh"
 
 #include <boost/uuid/uuid.hpp>
@@ -72,6 +77,7 @@ struct DNSQuestion
   const ComboAddress* local;
   const ComboAddress* remote;
   std::shared_ptr<QTag> qTag{nullptr};
+  std::shared_ptr<std::map<uint16_t, EDNSOptionView> > ednsOptions;
   struct dnsheader* dh;
   size_t size;
   unsigned int consumed{0};
index 6900c6799506831378374ebd6f9c156794c10fd4..75cbb556c5c65453f0226b8dc25c083a5aced130 100644 (file)
@@ -84,6 +84,14 @@ This state can be modified from the various hooks.
 
     :returns: true if the DO bit was set, false otherwise
 
+  .. method:: DNSQuestion:getEDNSOptions() -> table
+
+    .. versionadded:: 1.3.1
+
+    Return the list of EDNS Options, if any.
+
+    :returns: A table of EDNSOptionView objects, indexed on the ECS Option code
+
   .. method:: DNSQuestion:getTag(key) -> string
 
     .. versionadded:: 1.2.0
@@ -194,3 +202,22 @@ DNSHeader (``dh``) object
     Set checking disabled flag.
 
     :param bool cd: State of the CD flag
+
+.. _EDNSOptionView:
+
+EDNSOptionView object
+=====================
+
+.. class:: EDNSOptionView
+
+  .. versionadded:: 1.3.1
+
+  An object that represents the values of a single EDNS option received in a query.
+
+  .. attribute:: EDNSOptionView.count -> int
+
+    The number of values for this EDNS option.
+
+  .. method:: EDNSOptionView:getValues()
+
+    Return a table of NULL-safe strings values for this EDNS option.
index b88d52b0addcfb58ed379d15a3782281abf51f5a..9996b6edf8c8914913fd6eca47de0dd1d89a18e2 100644 (file)
@@ -55,6 +55,27 @@ static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=t
   BOOST_CHECK_EQUAL(mdp.d_header.arcount, (hasEdns ? 1 : 0));
 }
 
+static void validateECS(const char* packet, size_t packetSize, const ComboAddress& expected)
+{
+  ComboAddress rem("::1");
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  uint16_t qclass;
+  DNSName qname(packet, packetSize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &rem, const_cast<dnsheader*>(reinterpret_cast<const dnsheader*>(packet)), packetSize, packetSize, false, nullptr);
+  BOOST_CHECK(parseEDNSOptions(dq));
+  BOOST_REQUIRE(dq.ednsOptions != nullptr);
+  BOOST_CHECK_EQUAL(dq.ednsOptions->size(), 1);
+  const auto& ecsOption = dq.ednsOptions->find(EDNSOptionCode::ECS);
+  BOOST_REQUIRE(ecsOption != dq.ednsOptions->cend());
+
+  string expectedOption;
+  generateECSOption(expected, expectedOption, expected.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+  /* we need to skip the option code and length, which are not included */
+  BOOST_REQUIRE_EQUAL(ecsOption->second.values.size(), 1);
+  BOOST_CHECK_EQUAL(expectedOption.substr(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), std::string(ecsOption->second.values.at(0).content, ecsOption->second.values.at(0).size));
+}
+
 static void validateResponse(const char * packet, size_t packetSize, bool hasEdns, uint8_t additionalCount=0)
 {
   MOADNSParser mdp(false, packet, packetSize);
@@ -72,8 +93,10 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
 {
   bool ednsAdded = false;
   bool ecsAdded = false;
-  ComboAddress remote;
+  ComboAddress remote("192.0.2.1");
   DNSName name("www.powerdns.com.");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
 
   vector<uint8_t> query;
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
@@ -90,11 +113,12 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(packet, len);
+  validateECS(packet, len, remote);
 
   /* not large enough packet */
   ednsAdded = false;
@@ -105,18 +129,72 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(reinterpret_cast<char*>(query.data()), len, false);
 }
 
+BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
+{
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.0.2.1");
+  DNSName name("www.powerdns.com.");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  uint16_t qclass;
+  DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  BOOST_CHECK(qclass == QClass::IN);
+
+  DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(packet), sizeof(packet), query.size(), false, nullptr);
+  /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
+  BOOST_CHECK(!parseEDNSOptions(dq));
+
+  /* And now we add our own ECS */
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, dq.len);
+  validateECS(packet, dq.len, remote);
+
+  /* not large enough packet */
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  qname = DNSName(reinterpret_cast<char*>(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  BOOST_CHECK(qclass == QClass::IN);
+  DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded));
+  BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), dq2.len, false);
+}
+
 BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   bool ednsAdded = false;
   bool ecsAdded = false;
   ComboAddress remote;
   DNSName name("www.powerdns.com.");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
 
   vector<uint8_t> query;
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
@@ -135,11 +213,12 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, true);
   validateQuery(packet, len);
+  validateECS(packet, len, remote);
 
   /* not large enough packet */
   consumed = 0;
@@ -150,19 +229,74 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(reinterpret_cast<char*>(query.data()), len);
 }
 
+BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("2001:DB8::1");
+  DNSName name("www.powerdns.com.");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  pw.addOpt(512, 0, 0);
+  pw.commit();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  uint16_t qclass;
+  DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  BOOST_CHECK(qclass == QClass::IN);
+
+  DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(packet), sizeof(packet), query.size(), false, nullptr);
+  /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
+  BOOST_CHECK(parseEDNSOptions(dq));
+
+  /* And now we add our own ECS */
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
+  validateQuery(packet, dq.len);
+  validateECS(packet, dq.len, remote);
+
+  /* not large enough packet */
+  consumed = 0;
+  ednsAdded = false;
+  ecsAdded = false;
+  qname = DNSName(reinterpret_cast<char*>(query.data()), query.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  BOOST_CHECK(qclass == QClass::IN);
+  DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded));
+  BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), dq2.len);
+}
+
 BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   bool ednsAdded = false;
   bool ecsAdded = false;
   ComboAddress remote("192.168.1.25");
   DNSName name("www.powerdns.com.");
   ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
 
   vector<uint8_t> query;
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
@@ -186,11 +320,57 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(packet, len);
+  validateECS(packet, len, remote);
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com.");
+  ComboAddress origRemote("127.0.0.1");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
+  pw.addOpt(512, 0, 0, opts);
+  pw.commit();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  uint16_t qclass;
+  DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  BOOST_CHECK(qclass == QClass::IN);
+
+  DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(packet), sizeof(packet), query.size(), false, nullptr);
+  dq.ecsOverride = true;
+
+  /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
+  BOOST_CHECK(parseEDNSOptions(dq));
+
+  /* And now we add our own ECS */
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, dq.len);
+  validateECS(packet, dq.len, remote);
 }
 
 BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
@@ -199,6 +379,8 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
   ComboAddress remote("192.168.1.25");
   DNSName name("www.powerdns.com.");
   ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
 
   vector<uint8_t> query;
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
@@ -222,11 +404,12 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
   BOOST_CHECK((size_t) len < query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(packet, len);
+  validateECS(packet, len, remote);
 }
 
 BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
@@ -235,6 +418,8 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   ComboAddress remote("192.168.1.25");
   DNSName name("www.powerdns.com.");
   ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
 
   vector<uint8_t> query;
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
@@ -258,11 +443,12 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(packet, len);
+  validateECS(packet, len, remote);
 
   /* not large enough packet */
   ednsAdded = false;
@@ -273,7 +459,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
index f202b87bf0e66b774550ec723fcb163c7d2ada64..c02afb8437bbd1da93dfddc95d274e0d6a14740d 100644 (file)
@@ -491,18 +491,25 @@ class DNSDistTest(unittest.TestCase):
             for option in received.options:
                 self.assertEquals(option.otype, 10)
 
-    def checkMessageEDNSWithECS(self, expected, received):
+    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
         self.assertEquals(expected, received)
         self.assertEquals(received.edns, 0)
-        self.assertEquals(len(received.options), 1)
-        self.assertEquals(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE)
+        self.assertEquals(len(received.options), 1 + additionalOptions)
+        hasECS = False
+        for option in received.options:
+            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
+                hasECS = True
+            else:
+                self.assertNotEquals(additionalOptions, 0)
+
         self.compareOptions(expected.options, received.options)
+        self.assertTrue(hasECS)
 
-    def checkQueryEDNSWithECS(self, expected, received):
-        self.checkMessageEDNSWithECS(expected, received)
+    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
+        self.checkMessageEDNSWithECS(expected, received, additionalOptions)
 
-    def checkResponseEDNSWithECS(self, expected, received):
-        self.checkMessageEDNSWithECS(expected, received)
+    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
+        self.checkMessageEDNSWithECS(expected, received, additionalOptions)
 
     def checkQueryEDNSWithoutECS(self, expected, received):
         self.checkMessageEDNSWithoutECS(expected, received)
diff --git a/regression-tests.dnsdist/test_EDNSOptions.py b/regression-tests.dnsdist/test_EDNSOptions.py
new file mode 100644 (file)
index 0000000..f5838bf
--- /dev/null
@@ -0,0 +1,446 @@
+#!/usr/bin/env python
+import dns
+import clientsubnetoption
+import cookiesoption
+from dnsdisttests import DNSDistTest
+
+class EDNSOptionsBase(DNSDistTest):
+    _ednsTestFunction = """
+    function testEDNSOptions(dq)
+      local options = dq:getEDNSOptions()
+      local qname = dq.qname:toString()
+
+      if string.match(qname, 'noedns') then
+        if next(options) ~= nil then
+          return DNSAction.Spoof, "192.0.2.255"
+        end
+      end
+
+      if string.match(qname, 'multiplecookies') then
+        if options[EDNSOptionCode.COOKIE] == nil then
+          return DNSAction.Spoof, "192.0.2.1"
+        end
+        if options[EDNSOptionCode.COOKIE]:count() ~= 2 then
+          return DNSAction.Spoof, "192.0.2.2"
+        end
+        if options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+          return DNSAction.Spoof, "192.0.2.3"
+        end
+        if options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
+          return DNSAction.Spoof, "192.0.2.4"
+        end
+      elseif string.match(qname, 'cookie') then
+        if options[EDNSOptionCode.COOKIE] == nil then
+          return DNSAction.Spoof, "192.0.2.1"
+        end
+        if options[EDNSOptionCode.COOKIE]:count() ~= 1 or options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+          return DNSAction.Spoof, "192.0.2.2"
+        end
+      end
+
+      if string.match(qname, 'ecs4') then
+        if options[EDNSOptionCode.ECS] == nil then
+          return DNSAction.Spoof, "192.0.2.51"
+        end
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 8 then
+          return DNSAction.Spoof, "192.0.2.52"
+        end
+      end
+
+      if string.match(qname, 'ecs6') then
+        if options[EDNSOptionCode.ECS] == nil then
+          return DNSAction.Spoof, "192.0.2.101"
+        end
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 20 then
+          return DNSAction.Spoof, "192.0.2.102"
+        end
+      end
+
+      return DNSAction.None, ""
+
+    end
+    """
+
+class TestEDNSOptions(EDNSOptionsBase):
+
+    _config_template = """
+    %s
+
+    addLuaAction(AllRule(), testEDNSOptions)
+
+    newServer{address="127.0.0.1:%s"}
+    """
+    _config_params = ['_ednsTestFunction', '_testServerPort']
+
+    def testWithoutEDNS(self):
+        """
+        EDNS Options: No EDNS
+        """
+        name = 'noedns.ednsoptions.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.255')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+    def testCookie(self):
+        """
+        EDNS Options: Cookie
+        """
+        name = 'cookie.ednsoptions.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+    def testECS4(self):
+        """
+        EDNS Options: ECS4
+        """
+        name = 'ecs4.ednsoptions.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 32)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+    def testECS6(self):
+        """
+        EDNS Options: ECS6
+        """
+        name = 'ecs6.ednsoptions.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+    def testECS6Cookie(self):
+        """
+        EDNS Options: Cookie + ECS6
+        """
+        name = 'cookie-ecs6.ednsoptions.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+    def testMultiCookiesECS6(self):
+        """
+        EDNS Options: Two Cookies + ECS6
+        """
+        name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
+        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+class TestEDNSOptionsAddingECS(EDNSOptionsBase):
+
+    _config_template = """
+    %s
+
+    addLuaAction(AllRule(), testEDNSOptions)
+
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    """
+    _config_params = ['_ednsTestFunction', '_testServerPort']
+
+    def testWithoutEDNS(self):
+        """
+        EDNS Options: No EDNS (adding ECS)
+        """
+        name = 'noedns.ednsoptions-ecs.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+    def testCookie(self):
+        """
+        EDNS Options: Cookie (adding ECS)
+        """
+        name = 'cookie.ednsoptions-ecs.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[eco,ecso], payload=512)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+
+    def testECS4(self):
+        """
+        EDNS Options: ECS4 (adding ECS)
+        """
+        name = 'ecs4.ednsoptions-ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 32)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        ecsoResponse = clientsubnetoption.ClientSubnetOption('1.2.3.4', 24, scope=24)
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecsoResponse])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+    def testECS6(self):
+        """
+        EDNS Options: ECS6 (adding ECS)
+        """
+        name = 'ecs6.ednsoptions-ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        ecsoResponse = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128, scope=56)
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecsoResponse])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+    def testECS6Cookie(self):
+        """
+        EDNS Options: Cookie + ECS6 (adding ECS)
+        """
+        name = 'cookie-ecs6.ednsoptions-ecs.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
+        ecsoResponse = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128, scope=56)
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecsoResponse])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+    def testMultiCookiesECS6(self):
+        """
+        EDNS Options: Two Cookies + ECS6
+        """
+        name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
+        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)