]> granicus.if.org Git - pdns/commitdiff
dnsdist: Handle trailing data correctly when adding OPT or ECS info
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 21 Sep 2018 18:06:06 +0000 (20:06 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 7 Nov 2018 19:10:02 +0000 (20:10 +0100)
17 files changed:
pdns/dnsdist-console.cc
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist-xpf.cc [new file with mode: 0644]
pdns/dnsdist-xpf.hh [new file with mode: 0644]
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-xpf.cc [new symlink]
pdns/dnsdistdist/dnsdist-xpf.hh [new symlink]
pdns/dnsdistdist/docs/advanced/ecs.rst
pdns/dnsparser.cc
pdns/test-dnsdist_cc.cc
pdns/test-dnsparser_cc.cc

index a697912bcbb28178decb319a3590b16621d2f99b..10f1e91918409e441642974c3757e3e8ab7809f0 100644 (file)
@@ -430,6 +430,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setPayloadSizeOnSelfGeneratedAnswers", true, "payloadSize", "set the UDP payload size advertised via EDNS on self-generated responses" },
   { "setPoolServerPolicy", true, "policy, pool", "set the server selection policy for this pool to that policy" },
   { "setPoolServerPolicy", true, "name, func, pool", "set the server selection policy for this pool to one named 'name' and provided by 'function'" },
+  { "setPreserveTrailingData", true, "bool", "set whether trailing data should be preserved while adding ECS or XPF records to incoming queries" },
   { "setQueryCount", true, "bool", "set whether queries should be counted" },
   { "setQueryCountFilter", true, "func", "filter queries that would be counted, where `func` is a function with parameter `dq` which decides whether a query should and how it should be counted" },
   { "setRingBuffersLockRetries", true, "n", "set the number of attempts to get a non-blocking lock to a ringbuffer shard before blocking" },
index 71b852f25f4274b42d32147ad91494197dd3af9e..d1cd71a8673a44512a160ed3ddd9e887edc0d600 100644 (file)
@@ -353,7 +353,7 @@ static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uin
   return true;
 }
 
-static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded)
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData)
 {
   /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
   string EDNSRR;
@@ -365,17 +365,27 @@ static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t
     return false;
   }
 
+  uint32_t realPacketLen = getDNSPacketLength(packet, *len);
+  if (realPacketLen < *len && preserveTrailingData) {
+    size_t toMove = *len - realPacketLen;
+    memmove(packet + realPacketLen + EDNSRR.size(), packet + realPacketLen, toMove);
+    *len += EDNSRR.size();
+  }
+  else {
+    *len = realPacketLen + EDNSRR.size();
+  }
+
   uint16_t arcount = ntohs(dh->arcount);
   arcount++;
   dh->arcount = htons(arcount);
   *ednsAdded = true;
 
-  memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size());
-  *len += EDNSRR.size();
+  memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
+
   return true;
 }
 
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption)
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
 {
   assert(packet != nullptr);
   assert(len != nullptr);
@@ -388,7 +398,7 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
   int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
 
   if (res != 0) {
-    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded);
+    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData);
   }
 
   unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
@@ -412,14 +422,14 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
   return true;
 }
 
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded)
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData)
 {
   assert(dq.remote != nullptr);
   string newECSOption;
-  generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() :  dq.ecsPrefixLength);
+  generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength);
   char* packet = reinterpret_cast<char*>(dq.dh);
 
-  return handleEDNSClientSubnet(packet, dq.size, dq.consumed, &dq.len, ednsAdded, ecsAdded, dq.ecsOverride, newECSOption);
+  return handleEDNSClientSubnet(packet, dq.size, dq.consumed, &dq.len, ednsAdded, ecsAdded, dq.ecsOverride, newECSOption, preserveTrailingData);
 }
 
 static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
index 5b664407bcdfa99e3c96c81424cb04a44f7ee519..40bf974bd055ac870c51c5c584604557964324ae 100644 (file)
@@ -35,8 +35,8 @@ bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const s
 bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
 
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded);
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption);
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData);
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData);
 
 bool parseEDNSOptions(DNSQuestion& dq);
 
index 6e92336347914eead3465429c11103c5abc9332e..5cb2dc3e8e4f5fca2662b1afc559635ff68b70ec 100644 (file)
@@ -176,7 +176,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) con
       string newECSOption;
       generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() :  dq->ecsPrefixLength);
 
-      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsOverride, newECSOption)) {
+      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) {
         return DNSAction::Action::None;
       }
 
index 2bc27b1e680fb4bffd95d0f00d8d15e50ec381f5..1f468ebd2173a7f3751367e0faac38632eafbc1d 100644 (file)
@@ -913,6 +913,8 @@ void setupLuaConfig(bool client)
 
   g_lua.writeFunction("setECSOverride", [](bool override) { g_ECSOverride=override; });
 
+  g_lua.writeFunction("setPreserveTrailingData", [](bool preserve) { g_preserveTrailingData = preserve; });
+
   g_lua.writeFunction("showDynBlocks", []() {
       setLuaNoSideEffect();
       auto slow = g_dynblockNMG.getCopy();
index 656f8aad442e872c7f4862be59e91fabc6c5efac..98b6e78371cfdb94866110b0a244b42436b674db 100644 (file)
@@ -22,6 +22,7 @@
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-rings.hh"
+#include "dnsdist-xpf.hh"
 
 #include "dnsparser.hh"
 #include "ednsoptions.hh"
@@ -433,7 +434,7 @@ void* tcpClientThread(int pipefd)
         }
 
         if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
-          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded))) {
+          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
             vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
             goto drop;
           }
@@ -501,7 +502,7 @@ void* tcpClientThread(int pipefd)
         }
 
         if (dq.addXPF && ds->xpfRRCode != 0) {
-          addXPF(dq, ds->xpfRRCode);
+          addXPF(dq, ds->xpfRRCode, g_preserveTrailingData);
         }
 
        int dsock = -1;
diff --git a/pdns/dnsdist-xpf.cc b/pdns/dnsdist-xpf.cc
new file mode 100644 (file)
index 0000000..c828aad
--- /dev/null
@@ -0,0 +1,67 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist-xpf.hh"
+
+#include "dnsparser.hh"
+#include "xpf.hh"
+
+bool addXPF(DNSQuestion& dq, uint16_t optionCode, bool preserveTrailingData)
+{
+  std::string payload = generateXPFPayload(dq.tcp, *dq.remote, *dq.local);
+  uint8_t root = '\0';
+  dnsrecordheader drh;
+  drh.d_type = htons(optionCode);
+  drh.d_class = htons(QClass::IN);
+  drh.d_ttl = 0;
+  drh.d_clen = htons(payload.size());
+  size_t recordHeaderLen = sizeof(root) + sizeof(drh);
+
+  size_t available = dq.size - dq.len;
+
+  if ((payload.size() + recordHeaderLen) > available) {
+    return false;
+  }
+
+  size_t xpfSize = sizeof(root) + sizeof(drh) + payload.size();
+  uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(dq.dh), dq.len);
+  if (realPacketLen < dq.len && preserveTrailingData) {
+    size_t toMove = dq.len - realPacketLen;
+    memmove(reinterpret_cast<char*>(dq.dh) + realPacketLen + xpfSize, reinterpret_cast<const char*>(dq.dh) + realPacketLen, toMove);
+    dq.len += xpfSize;
+  }
+  else {
+    dq.len = realPacketLen + xpfSize;
+  }
+
+  size_t pos = realPacketLen;
+  memcpy(reinterpret_cast<char*>(dq.dh) + pos, &root, sizeof(root));
+  pos += sizeof(root);
+  memcpy(reinterpret_cast<char*>(dq.dh) + pos, &drh, sizeof(drh));
+  pos += sizeof(drh);
+  memcpy(reinterpret_cast<char*>(dq.dh) + pos, payload.data(), payload.size());
+  pos += payload.size();
+
+  dq.dh->arcount = htons(ntohs(dq.dh->arcount) + 1);
+
+  return true;
+}
diff --git a/pdns/dnsdist-xpf.hh b/pdns/dnsdist-xpf.hh
new file mode 100644 (file)
index 0000000..5a1b411
--- /dev/null
@@ -0,0 +1,27 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "dnsdist.hh"
+
+bool addXPF(DNSQuestion& dq, uint16_t optionCode, bool preserveTrailingData);
+
index 48805687061864b1f5621315c1e1cdcfa77d9479..ecd402b82669325e7786d0778ed4b7189194f529 100644 (file)
@@ -48,6 +48,7 @@
 #include "dnsdist-lua.hh"
 #include "dnsdist-rings.hh"
 #include "dnsdist-secpoll.hh"
+#include "dnsdist-xpf.hh"
 
 #include "base64.hh"
 #include "delaypipe.hh"
@@ -62,7 +63,6 @@
 #include "sodcrypto.hh"
 #include "sstuff.hh"
 #include "threadname.hh"
-#include "xpf.hh"
 
 thread_local boost::uuids::random_generator t_uuidGenerator;
 
@@ -143,7 +143,8 @@ int g_udpTimeout{2};
 
 bool g_servFailOnNoPolicy{false};
 bool g_truncateTC{false};
-bool g_fixupCase{0};
+bool g_fixupCase{false};
+bool g_preserveTrailingData{false};
 
 static void truncateTC(char* packet, uint16_t* len, size_t responseSize, unsigned int consumed)
 try
@@ -1197,38 +1198,6 @@ static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd,
   return result;
 }
 
-bool addXPF(DNSQuestion& dq, uint16_t optionCode)
-{
-  std::string payload = generateXPFPayload(dq.tcp, *dq.remote, *dq.local);
-  uint8_t root = '\0';
-  dnsrecordheader drh;
-  drh.d_type = htons(optionCode);
-  drh.d_class = htons(QClass::IN);
-  drh.d_ttl = 0;
-  drh.d_clen = htons(payload.size());
-  size_t recordHeaderLen = sizeof(root) + sizeof(drh);
-
-  size_t available = dq.size - dq.len;
-
-  if ((payload.size() + recordHeaderLen) > available) {
-    return false;
-  }
-
-  size_t pos = dq.len;
-  memcpy(reinterpret_cast<char*>(dq.dh) + pos, &root, sizeof(root));
-  pos += sizeof(root);
-  memcpy(reinterpret_cast<char*>(dq.dh) + pos, &drh, sizeof(drh));
-  pos += sizeof(drh);
-  memcpy(reinterpret_cast<char*>(dq.dh) + pos, payload.data(), payload.size());
-  pos += payload.size();
-
-  dq.len = pos;
-
-  dq.dh->arcount = htons(ntohs(dq.dh->arcount) + 1);
-
-  return true;
-}
-
 static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest)
 {
   if (msgh->msg_flags & MSG_TRUNC) {
@@ -1422,7 +1391,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     bool ednsAdded = false;
     bool ecsAdded = false;
     if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
-      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded))) {
+      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
         vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
         return;
       }
@@ -1514,7 +1483,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     if (dq.addXPF && ss->xpfRRCode != 0) {
-      addXPF(dq, ss->xpfRRCode);
+      addXPF(dq, ss->xpfRRCode, g_preserveTrailingData);
     }
 
     ss->queries++;
index 566a2c245e637c157e8e26280454a18471b2f084..8cb7c062cae24467d0ef8ce43e4a69a8c4c07553 100644 (file)
@@ -948,6 +948,7 @@ extern uint32_t g_hashperturb;
 extern bool g_useTCPSinglePipe;
 extern std::atomic<uint16_t> g_downstreamTCPCleanupInterval;
 extern size_t g_udpVectorSize;
+extern bool g_preserveTrailingData;
 
 #ifdef HAVE_EBPF
 extern shared_ptr<BPFFilter> g_defaultBPFFilter;
index f2d79be915783a01f9258f0cb9080b75322b5132..92c8632ce89c6653679cd3982c483d69b64688fb 100644 (file)
@@ -111,6 +111,7 @@ dnsdist_SOURCES = \
        dnsdist-snmp.cc dnsdist-snmp.hh \
        dnsdist-tcp.cc \
        dnsdist-web.cc \
+       dnsdist-xpf.cc dnsdist-xpf.hh \
        dnslabeltext.cc \
        dnsname.cc dnsname.hh \
        dnsparser.hh dnsparser.cc \
@@ -231,6 +232,7 @@ testrunner_SOURCES = \
        dnsdist.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnsdist-xpf.cc dnsdist-xpf.hh \
        dnscrypt.cc dnscrypt.hh \
        dnslabeltext.cc \
        dnsname.cc dnsname.hh \
diff --git a/pdns/dnsdistdist/dnsdist-xpf.cc b/pdns/dnsdistdist/dnsdist-xpf.cc
new file mode 120000 (symlink)
index 0000000..66fd88d
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-xpf.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/dnsdist-xpf.hh b/pdns/dnsdistdist/dnsdist-xpf.hh
new file mode 120000 (symlink)
index 0000000..c2b75e2
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-xpf.hh
\ No newline at end of file
index b8ce4b2d01c8b8f9f750b73d66d970698d6f2954..643f467a02c281a196ed6ec268e9ae070c58ae40 100644 (file)
@@ -14,3 +14,5 @@ In addition to the global settings, rules and Lua bindings can alter this behavi
  * calling :func:`ECSPrefixLengthAction(v4, v6)` or setting ``dq.ecsPrefixLength`` will override the global :func:`setECSSourcePrefixV4()` and :func:`setECSSourcePrefixV6()` values.
 
 In effect this means that for the EDNS Client Subnet option to be added to the request, ``useClientSubnet`` should be set to ``true`` for the backend used (default to ``false``) and ECS should not have been disabled by calling :func:`DisableECSAction` or setting ``dq.useECS`` to ``false`` (default to true).
+
+Note that any trailing data present in the incoming query is removed by default when an OPT (or XPF) record has to be inserted. This behaviour can be modified using :func:`setPreserveTrailingData()`.
index e8888111365eafc1e3fecde1ea0906f430a73c9b..4103df84f565ac195e4771596f935f73ceb1bc77 100644 (file)
@@ -811,7 +811,7 @@ uint32_t getDNSPacketLength(const char* packet, size_t length)
   }
   try
   {
-    const dnsheader* dh = (const dnsheader*) packet;
+    const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
     DNSPacketMangler dpm(const_cast<char*>(packet), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
index 9996b6edf8c8914913fd6eca47de0dd1d89a18e2..c579f40c11bf0523c4670eb1e155ef0180956e0c 100644 (file)
 #define BOOST_TEST_NO_MAIN
 
 #include <boost/test/unit_test.hpp>
+#include <unistd.h>
 
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
+#include "dnsdist-xpf.hh"
+
 #include "dolog.hh"
 #include "dnsname.hh"
 #include "dnsparser.hh"
 #include "ednsoptions.hh"
 #include "ednscookies.hh"
 #include "ednssubnet.hh"
-#include <unistd.h>
 
 BOOST_AUTO_TEST_SUITE(test_dnsdist_cc)
 
-bool g_syslog{true};
-bool g_verbose{true};
-
 static const uint16_t ECSSourcePrefixV4 = 24;
 static const uint16_t ECSSourcePrefixV6 = 56;
 
-static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true)
+static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false)
 {
   MOADNSParser mdp(true, packet, packetSize);
 
@@ -52,7 +51,8 @@ static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=t
   BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1);
   BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0);
   BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0);
-  BOOST_CHECK_EQUAL(mdp.d_header.arcount, (hasEdns ? 1 : 0));
+  uint16_t expectedARCount = 0 + (hasEdns ? 1 : 0) + (hasXPF ? 1 : 0);
+  BOOST_CHECK_EQUAL(mdp.d_header.arcount, expectedARCount);
 }
 
 static void validateECS(const char* packet, size_t packetSize, const ComboAddress& expected)
@@ -89,6 +89,124 @@ static void validateResponse(const char * packet, size_t packetSize, bool hasEdn
   BOOST_CHECK_EQUAL(mdp.d_header.arcount, (hasEdns ? 1 : 0) + additionalCount);
 }
 
+BOOST_AUTO_TEST_CASE(test_addXPF)
+{
+  static const uint16_t xpfOptionCode = 65422;
+
+  struct timespec queryTime;
+  gettime(&queryTime);  // does not have to be accurate ("realTime") in tests
+  ComboAddress remote;
+  DNSName name("www.powerdns.com.");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  const uint16_t len = query.size();
+  vector<uint8_t> queryWithXPF;
+
+  {
+    char packet[1500];
+    memcpy(packet, query.data(), query.size());
+
+    /* large enough packet */
+    unsigned int consumed = 0;
+    uint16_t qtype;
+    DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+    BOOST_CHECK_EQUAL(qname, name);
+    BOOST_CHECK(qtype == QType::A);
+
+    auto dh = reinterpret_cast<dnsheader*>(packet);
+    DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime);
+
+    BOOST_CHECK(addXPF(dq, xpfOptionCode, false));
+    BOOST_CHECK(static_cast<size_t>(dq.len) > query.size());
+    validateQuery(packet, dq.len, false, true);
+    queryWithXPF.resize(dq.len);
+    memcpy(queryWithXPF.data(), packet, dq.len);
+  }
+
+  {
+    char packet[1500];
+    memcpy(packet, query.data(), query.size());
+
+    /* not large enough packet */
+    unsigned int consumed = 0;
+    uint16_t qtype;
+    DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+    BOOST_CHECK_EQUAL(qname, name);
+    BOOST_CHECK(qtype == QType::A);
+
+    auto dh = reinterpret_cast<dnsheader*>(packet);
+    DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime);
+    dq.size = dq.len;
+
+    BOOST_CHECK(!addXPF(dq, xpfOptionCode, false));
+    BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), query.size());
+    validateQuery(packet, dq.len, false, false);
+  }
+
+  {
+    char packet[1500];
+    memcpy(packet, query.data(), query.size());
+
+    /* packet with trailing data (overriding it) */
+    unsigned int consumed = 0;
+    uint16_t qtype;
+    DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+    BOOST_CHECK_EQUAL(qname, name);
+    BOOST_CHECK(qtype == QType::A);
+
+    auto dh = reinterpret_cast<dnsheader*>(packet);
+    DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime);
+
+    /* add trailing data */
+    const size_t trailingDataSize = 10;
+    /* Making sure we have enough room to allow for fake trailing data */
+    BOOST_REQUIRE(sizeof(packet) > dq.len && (sizeof(packet) - dq.len) > trailingDataSize);
+    for (size_t idx = 0; idx < trailingDataSize; idx++) {
+      packet[dq.len + idx] = 'A';
+    }
+    dq.len += trailingDataSize;
+
+    BOOST_CHECK(addXPF(dq, xpfOptionCode, false));
+    BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), queryWithXPF.size());
+    BOOST_CHECK_EQUAL(memcmp(queryWithXPF.data(), packet, queryWithXPF.size()), 0);
+    validateQuery(packet, dq.len, false, true);
+  }
+
+  {
+    char packet[1500];
+    memcpy(packet, query.data(), query.size());
+
+    /* packet with trailing data (preserving trailing data) */
+    unsigned int consumed = 0;
+    uint16_t qtype;
+    DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+    BOOST_CHECK_EQUAL(qname, name);
+    BOOST_CHECK(qtype == QType::A);
+
+    auto dh = reinterpret_cast<dnsheader*>(packet);
+    DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime);
+
+    /* add trailing data */
+    const size_t trailingDataSize = 10;
+    /* Making sure we have enough room to allow for fake trailing data */
+    BOOST_REQUIRE(sizeof(packet) > dq.len && (sizeof(packet) - dq.len) > trailingDataSize);
+    for (size_t idx = 0; idx < trailingDataSize; idx++) {
+      packet[dq.len + idx] = 'A';
+    }
+    dq.len += trailingDataSize;
+
+    BOOST_CHECK(addXPF(dq, xpfOptionCode, true));
+    BOOST_CHECK(static_cast<size_t>(dq.len) > queryWithXPF.size());
+    BOOST_CHECK_EQUAL(memcmp(queryWithXPF.data(), packet, queryWithXPF.size()), 0);
+    for (size_t idx = 0; idx < trailingDataSize; idx++) {
+      BOOST_CHECK_EQUAL(packet[queryWithXPF.size() + idx], 'A');
+    }
+    validateQuery(packet, dq.len, false, true);
+  }
+}
+
 BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
 {
   bool ednsAdded = false;
@@ -109,31 +227,84 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
 
   unsigned int consumed = 0;
   uint16_t qtype;
-  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
-  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(static_cast<size_t>(len) > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(packet, len);
   validateECS(packet, len, remote);
+  vector<uint8_t> queryWithEDNS;
+  queryWithEDNS.resize(len);
+  memcpy(queryWithEDNS.data(), packet, len);
 
   /* not large enough packet */
   ednsAdded = false;
   ecsAdded = false;
   consumed = 0;
   len = query.size();
-  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
-  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK_EQUAL(static_cast<size_t>(len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(reinterpret_cast<char*>(query.data()), len, false);
+
+  /* packet with trailing data (overriding it) */
+  memcpy(packet, query.data(), query.size());
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  /* add trailing data */
+  const size_t trailingDataSize = 10;
+  /* Making sure we have enough room to allow for fake trailing data */
+  BOOST_REQUIRE(sizeof(packet) > len && (sizeof(packet) - len) > trailingDataSize);
+  for (size_t idx = 0; idx < trailingDataSize; idx++) {
+    packet[len + idx] = 'A';
+  }
+  len += trailingDataSize;
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size());
+  BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, len);
+
+  /* packet with trailing data (preserving trailing data) */
+  memcpy(packet, query.data(), query.size());
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+  /* add trailing data */
+  /* Making sure we have enough room to allow for fake trailing data */
+  BOOST_REQUIRE(sizeof(packet) > len && (sizeof(packet) - len) > trailingDataSize);
+  for (size_t idx = 0; idx < trailingDataSize; idx++) {
+    packet[len + idx] = 'A';
+  }
+  len += trailingDataSize;
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, true));
+  BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size() + trailingDataSize);
+  BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
+  for (size_t idx = 0; idx < trailingDataSize; idx++) {
+    BOOST_CHECK_EQUAL(packet[queryWithEDNS.size() + idx], 'A');
+  }
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, len);
 }
 
 BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
@@ -164,7 +335,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK(!parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
   BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -181,7 +352,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK(qclass == QClass::IN);
   DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded));
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -213,7 +384,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, true);
@@ -229,7 +400,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -265,7 +436,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK(parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
   BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, true);
@@ -282,7 +453,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK(qclass == QClass::IN);
   DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded));
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -320,7 +491,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -365,7 +536,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) {
   BOOST_CHECK(parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -404,7 +575,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
   BOOST_CHECK((size_t) len < query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -443,7 +614,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -459,7 +630,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
index e9f248cb6b318222a647337e94623884725ce20d..4a2f6ddfd403231abbce77aaf6f6f2e71f6d5368 100644 (file)
@@ -417,6 +417,32 @@ BOOST_AUTO_TEST_CASE(test_getDNSPacketLength) {
     BOOST_CHECK_EQUAL(result, realSize);
   }
 
+  {
+    /* truncated packet, should return the full size */
+    vector<uint8_t> packet;
+    DNSPacketWriter pwR(packet, name, QType::A, QClass::IN, 0);
+    pwR.getHeader()->qr = 1;
+    pwR.commit();
+
+    pwR.startRecord(name, QType::A, 255, QClass::IN, DNSResourceRecord::ANSWER);
+    pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+    pwR.commit();
+
+    pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY);
+    pwR.commit();
+
+    pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL);
+    pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+    pwR.commit();
+
+    pwR.addOpt(4096, 0, 0);
+    pwR.commit();
+
+    size_t fakeSize = packet.size()-1;
+    auto result = getDNSPacketLength(reinterpret_cast<char*>(packet.data()), fakeSize);
+    BOOST_CHECK_EQUAL(result, fakeSize);
+  }
+
 }
 
 BOOST_AUTO_TEST_CASE(test_getRecordsOfTypeCount) {