From e0fd37ec1c560f232ce78ab0f0a531df56fdf33b Mon Sep 17 00:00:00 2001
From: Remi Gacogne <remi.gacogne@powerdns.com>
Date: Mon, 20 Aug 2018 15:21:10 +0200
Subject: [PATCH] dnsdist: Handle EDNS on truncateTC answers

---
 pdns/dnsdist-ecs.cc                     | 16 ++++----
 pdns/dnsdist-ecs.hh                     |  2 +-
 pdns/dnsdist-lua-actions.cc             |  2 +-
 pdns/dnsdist.cc                         | 17 +++++++-
 pdns/dnsparser.cc                       | 16 ++++++--
 pdns/dnsparser.hh                       |  2 +-
 pdns/test-dnsdist_cc.cc                 | 53 +++++++++++++++++++------
 regression-tests.dnsdist/test_Basics.py | 47 ++++++++++++++++++++++
 8 files changed, 125 insertions(+), 30 deletions(-)

diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc
index cf8b73e8a..b94b0eeee 100644
--- a/pdns/dnsdist-ecs.cc
+++ b/pdns/dnsdist-ecs.cc
@@ -559,23 +559,23 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
   return 0;
 }
 
-bool addEDNS(DNSQuestion& dq, bool dnssecOK)
+bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize)
 {
-  if (dq.dh->arcount != 0) {
+  if (dh->arcount != 0) {
     return false;
   }
 
   std::string optRecord;
-  generateOptRR(std::string(), optRecord, g_PayloadSizeSelfGenAnswers, dnssecOK);
+  generateOptRR(std::string(), optRecord, payloadSize, dnssecOK);
 
-  if (optRecord.size() >= dq.size || (dq.size - optRecord.size()) < dq.len) {
+  if (optRecord.size() >= size || (size - optRecord.size()) < len) {
     return false;
   }
 
-  char * optPtr = reinterpret_cast<char*>(dq.dh) + dq.len;
+  char * optPtr = reinterpret_cast<char*>(dh) + len;
   memcpy(optPtr, optRecord.data(), optRecord.size());
-  dq.len += optRecord.size();
-  dq.dh->arcount = htons(1);
+  len += optRecord.size();
+  dh->arcount = htons(1);
 
   return true;
 }
@@ -611,7 +611,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
 
   if (g_addEDNSToSelfGeneratedResponses) {
     /* now we need to add a new OPT record */
-    return addEDNS(dq, dnssecOK);
+    return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers);
   }
 
   /* otherwise we are just fine */
diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh
index e0b92a42c..c9193f61d 100644
--- a/pdns/dnsdist-ecs.hh
+++ b/pdns/dnsdist-ecs.hh
@@ -32,7 +32,7 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
 int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
 int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, char ** optRDLen, size_t * remaining);
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind);
-bool addEDNS(DNSQuestion& dq, bool dnssecOK);
+bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
 
 int getEDNSZ(const DNSQuestion& dq);
diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc
index 7724fda54..a9f64a40f 100644
--- a/pdns/dnsdist-lua-actions.cc
+++ b/pdns/dnsdist-lua-actions.cc
@@ -458,7 +458,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, string* ruleresult) c
   dq->dh->ancount = htons(dq->dh->ancount);
 
   if (hadEDNS && g_addEDNSToSelfGeneratedResponses) {
-    addEDNS(*dq, dnssecOK);
+    addEDNS(dq->dh, dq->len, dq->size, dnssecOK, g_PayloadSizeSelfGenAnswers);
   }
 
   return Action::HeaderModify;
diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc
index 8abb56e41..39018f851 100644
--- a/pdns/dnsdist.cc
+++ b/pdns/dnsdist.cc
@@ -52,6 +52,7 @@
 #include "delaypipe.hh"
 #include "dolog.hh"
 #include "dnsname.hh"
+#include "dnsparser.hh"
 #include "dnswriter.hh"
 #include "ednsoptions.hh"
 #include "gettime.hh"
@@ -142,12 +143,24 @@ bool g_servFailOnNoPolicy{false};
 bool g_truncateTC{false};
 bool g_fixupCase{0};
 
-static void truncateTC(char* packet, uint16_t* len, unsigned int consumed)
+static void truncateTC(char* packet, uint16_t* len, size_t responseSize, unsigned int consumed)
 try
 {
+  bool hadEDNS = false;
+  uint16_t payloadSize = 0;
+  uint16_t z = 0;
+
+  if (g_addEDNSToSelfGeneratedResponses) {
+    hadEDNS = getEDNSUDPPayloadSizeAndZ(packet, *len, &payloadSize, &z);
+  }
+
   *len=(uint16_t) (sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
   struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
   dh->ancount = dh->arcount = dh->nscount = 0;
+
+  if (hadEDNS) {
+    addEDNS(dh, *len, responseSize, z & EDNS_HEADER_FLAG_DO, payloadSize);
+  }
 }
 catch(...)
 {
@@ -475,7 +488,7 @@ try {
         }
 
         if(dh->tc && g_truncateTC) {
-          truncateTC(response, &responseLen, consumed);
+          truncateTC(response, &responseLen, responseSize, consumed);
         }
 
         dh->id = ids->origID;
diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc
index 2fc13ac38..e26bdf708 100644
--- a/pdns/dnsparser.cc
+++ b/pdns/dnsparser.cc
@@ -934,12 +934,15 @@ uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t sectio
   return result;
 }
 
-uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length)
+bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
 {
   if (length < sizeof(dnsheader)) {
-    return 0;
+    return false;
   }
 
+  *payloadSize = 0;
+  *z = 0;
+
   try
   {
     const dnsheader* dh = (const dnsheader*) packet;
@@ -958,7 +961,11 @@ uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length)
       const uint16_t dnsclass = dpm.get16BitInt();
 
       if(dnstype == QType::OPT) {
-        return dnsclass;
+        /* skip extended rcode and version */
+        dpm.skipBytes(2);
+        *z = dpm.get16BitInt();
+        *payloadSize = dnsclass;
+        return true;
       }
 
       /* TTL */
@@ -969,5 +976,6 @@ uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length)
   catch(...)
   {
   }
-  return 0;
+
+  return false;
 }
diff --git a/pdns/dnsparser.hh b/pdns/dnsparser.hh
index b0182ebb5..3cc710f43 100644
--- a/pdns/dnsparser.hh
+++ b/pdns/dnsparser.hh
@@ -401,7 +401,7 @@ void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_
 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
 uint32_t getDNSPacketLength(const char* packet, size_t length);
 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
-uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length);;
+bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z);
 
 template<typename T>
 std::shared_ptr<T> getRR(const DNSRecord& dr)
diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc
index 9bd45c2e5..0f8901c56 100644
--- a/pdns/test-dnsdist_cc.cc
+++ b/pdns/test-dnsdist_cc.cc
@@ -780,6 +780,8 @@ static int getZ(const DNSName& qname, const uint16_t qtype, const uint16_t qclas
 
 BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
 
+  uint16_t z;
+  uint16_t udpPayloadSize;
   DNSName qname("www.powerdns.com.");
   uint16_t qtype = QType::A;
   uint16_t qclass = QClass::IN;
@@ -801,7 +803,9 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
     pw.commit();
 
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), false);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 0);
   }
 
   {
@@ -813,7 +817,9 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
 
     query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1));
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), false);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 0);
   }
 
   {
@@ -824,7 +830,9 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
     pw.commit();
 
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 512);
   }
 
   {
@@ -835,7 +843,9 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
     pw.commit();
 
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), EDNS_HEADER_FLAG_DO);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 512);
   }
 
     {
@@ -846,7 +856,9 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
     pw.commit();
 
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 512);
   }
 
   {
@@ -857,13 +869,17 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
     pw.commit();
 
     BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), EDNS_HEADER_FLAG_DO);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 512);
   }
 
 }
 
 BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
+  uint16_t z;
+  uint16_t udpPayloadSize;
   DNSName qname("www.powerdns.com.");
   uint16_t qtype = QType::A;
   uint16_t qclass = QClass::IN;
@@ -892,7 +908,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), false);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 0);
   }
 
   {
@@ -905,8 +923,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
     query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1));
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
-    /* 512, because we don't touch a broken OPT record */
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), 512);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), false);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, 0);
   }
 
   {
@@ -918,7 +937,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers);
   }
 
   {
@@ -930,7 +951,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers);
   }
 
   {
@@ -942,7 +965,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, 0);
+    BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers);
   }
 
   {
@@ -954,7 +979,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
 
     auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
     BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO);
-    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &udpPayloadSize, &z), true);
+    BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers);
   }
 }
 
diff --git a/regression-tests.dnsdist/test_Basics.py b/regression-tests.dnsdist/test_Basics.py
index e9c122e70..ea4b7b08b 100644
--- a/regression-tests.dnsdist/test_Basics.py
+++ b/regression-tests.dnsdist/test_Basics.py
@@ -142,6 +142,48 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
         response.flags |= dns.flags.TC
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(expectedResponse.flags, receivedResponse.flags)
+        self.assertEquals(expectedResponse.question, receivedResponse.question)
+        self.assertFalse(response.answer == receivedResponse.answer)
+        self.assertEquals(len(receivedResponse.answer), 0)
+        self.assertEquals(len(receivedResponse.authority), 0)
+        self.assertEquals(len(receivedResponse.additional), 0)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+    def testTruncateTCEDNS(self):
+        """
+        Basics: Truncate TC with EDNS
+
+        dnsdist is configured to truncate TC (default),
+        we make the backend send responses
+        with TC set and additional content,
+        and check that the received response has been fixed.
+        Note that the query and initial response had EDNS,
+        so the final response should have it too.
+        """
+        name = 'atruncatetc.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        response = dns.message.make_response(query)
+        # force a different responder payload than the one in the query,
+        # so we check that we don't just mirror it
+        response.payload = 4242
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+        response.flags |= dns.flags.TC
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.payload = 4242
+        expectedResponse.flags |= dns.flags.TC
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         receivedQuery.id = query.id
@@ -152,6 +194,11 @@ class TestBasics(DNSDistTest):
         self.assertEquals(len(receivedResponse.answer), 0)
         self.assertEquals(len(receivedResponse.authority), 0)
         self.assertEquals(len(receivedResponse.additional), 0)
+        print(expectedResponse)
+        print(receivedResponse)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 4242)
 
     def testRegexReturnsRefused(self):
         """
-- 
2.40.0