]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add tag-based routing
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 1 Dec 2017 18:21:20 +0000 (19:21 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 1 Dec 2017 18:21:20 +0000 (19:21 +0100)
pdns/dnsdist-console.cc
pdns/dnsdist-lua2.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsrulactions.hh
regression-tests.dnsdist/test_Tags.py [new file with mode: 0644]

index d70a646747e438dc30e57e5530fdb93d108a40bd..5f7fa35c38f3203f10dc799382efd2da1c9be585 100644 (file)
@@ -423,6 +423,9 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "SNMPTrapAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the query description"},
   { "SNMPTrapResponseAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the response description"},
   { "SpoofAction", true, "{ip, ...} ", "forge a response with the specified IPv4 (for an A query) or IPv6 (for an AAAA). If you specify multiple addresses, all that match the query type (A, AAAA or ANY) will get spoofed in" },
+  { "TagAction", true, "name, value", "set the tag named 'name' to the given value" },
+  { "TagResponseAction", true, "name, value", "set the tag named 'name' to the given value" },
+  { "TagRule", true, "name [, value]", "matches if the tag named 'name' is present, with the given 'value' matching if any" },
   { "TCAction", true, "", "create answer to query with TC and RD bits set, to move to TCP" },
   { "testCrypto", true, "", "test of the crypto all works" },
   { "TimedIPSetRule", true, "", "Create a rule which matches a set of IP addresses which expire"}, 
index 7e043ef395f9934701413125f41a51d62fbda25e..23908b0f0da7df42dd45bdec9280becdd9e09f28 100644 (file)
@@ -1481,4 +1481,16 @@ void moreLua(bool client)
         g_outputBuffer="recvmmsg support is not available!\n";
 #endif
       });
+
+    g_lua.writeFunction("TagAction", [](std::string tag, std::string value) {
+        return std::shared_ptr<DNSAction>(new TagAction(tag, value));
+      });
+
+    g_lua.writeFunction("TagResponseAction", [](std::string tag, std::string value) {
+        return std::shared_ptr<DNSResponseAction>(new TagResponseAction(tag, value));
+      });
+
+    g_lua.writeFunction("TagRule", [](std::string tag, boost::optional<std::string> value) {
+        return std::shared_ptr<DNSRule>(new TagRule(tag, value));
+      });
 }
index e74ac4a398763d99b4fa14e85a6ca2993eaf7b3b..e8f8857e01d05799d14f66f6e8248154cdf62c63 100644 (file)
@@ -390,6 +390,8 @@ void* tcpClientThread(int pipefd)
 #ifdef HAVE_PROTOBUF
             dr.uniqueId = dq.uniqueId;
 #endif
+            dr.qTag = dq.qTag;
+
             if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
               goto drop;
             }
@@ -542,6 +544,8 @@ void* tcpClientThread(int pipefd)
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
+        dr.qTag = dq.qTag;
+
         if (!processResponse(localRespRulactions, dr, &delayMsec)) {
           break;
         }
index 4e646dca2fcd662733347e5fc9caf19f3c888f0a..ef859d6d9426c6deaf48be5d3129ddef4f3d83a0 100644 (file)
@@ -448,6 +448,8 @@ try {
 #ifdef HAVE_PROTOBUF
       dr.uniqueId = ids->uniqueId;
 #endif
+      dr.qTag = ids->qTag;
+
       if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
         continue;
       }
@@ -1272,6 +1274,8 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
+        dr.qTag = dq.qTag;
+
         if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
           return;
         }
@@ -1361,6 +1365,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     ids->packetCache = packetCache;
     ids->ednsAdded = ednsAdded;
     ids->ecsAdded = ecsAdded;
+    ids->qTag = dq.qTag;
 
     /* If we couldn't harvest the real dest addr, still
        write down the listening addr since it will be useful
index c9ff9bb5a6fd54801eea0a4e7a03151e1af018d9..7b9b1fd1679bfb680d766e251536563aeda8f9f8 100644 (file)
@@ -65,20 +65,20 @@ public:
   {
   }
 
-  void add(const std::string strLabel, const std::string strValue)
+  void add(const std::string& strLabel, const std::string& strValue)
   {
-    tagData.insert( {strLabel, strValue});
+    tagData.insert({strLabel, strValue});
     return;
   }
 
   std::string getMatch(const std::string& strLabel)  const
   {
-    std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
-    if(got == tagData.end()) {
+    const auto got = tagData.find(strLabel);
+    if (got == tagData.cend()) {
       return "";
-    } else {
-      return got->second;
     }
+
+    return got->second;
   }
 
   std::string getEntry(size_t iEntry) const
@@ -117,7 +117,7 @@ public:
     return strRet;
   }
 
-  std::unordered_map<std::string, std::string>tagData;
+  std::unordered_map<std::string, std::string> tagData;
 
 private:
   static constexpr char const *strSep = "\t";
@@ -410,6 +410,7 @@ struct IDState
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
   std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+  std::shared_ptr<QTag> qTag{nullptr};
   const ClientState* cs{nullptr};
   uint32_t cacheKey;                                          // 8
   std::atomic<uint16_t> age;                                  // 4
index 583916de311b3b77383f037fe6c4c986b348e449..64f7b7fc8ab19902cc03eef90e733a94b0dd8c0a 100644 (file)
@@ -528,6 +528,13 @@ These ``DNSRule``\ s be one of the following items:
   :param SuffixMatchNode smb: The SuffixMatchNode to match on
   :param bool quiet: Do not return the list of matched domains. Default is false.
 
+.. function:: TagRule(name [, value])
+
+  Matches question or answer with a tag named ``name`` set. If ``value`` is specified, the existing tag value should match too.
+
+  :param bool name: The name of the tag that has to be set
+  :param bool value: If set, the value the tag has to be set to. Default is unset
+
 .. function:: TCPRule([tcp])
 
   Matches question received over TCP if ``tcp`` is true, over UDP otherwise.
@@ -742,6 +749,20 @@ The following actions exist.
 
   :param string cname: The name to respond with
 
+.. function:: TagAction(name, value)
+
+  Associate a tag named ``name`` with a value of ``value`` to this query, that will be passed on to the response.
+
+  :param string name: The name of the tag to set
+  :param string cname: The value of the tag
+
+.. function:: TagResponseAction(name, value)
+
+  Associate a tag named ``name`` with a value of ``value`` to this response.
+
+  :param string name: The name of the tag to set
+  :param string cname: The value of the tag
+
 .. function:: TCAction()
 
   Create answer to query with TC and RD bits set, to force the client to TCP.
index 8f824efa80409aac27c98f599ba2933fc8abc1ed..254103e9c851f6ea18d08396e8be347ffb9bb29e 100644 (file)
@@ -768,9 +768,44 @@ public:
   }
   bool matches(const DNSQuestion* dq) const override;
   string toString() const override;
+private:
   double d_proba;
 };
 
+class TagRule : public DNSRule
+{
+public:
+  TagRule(std::string tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (dq->qTag == nullptr) {
+      return false;
+    }
+
+    const auto got = dq->qTag->tagData.find(d_tag);
+    if (got == dq->qTag->tagData.cend()) {
+      return false;
+    }
+
+    if (!d_value) {
+      return true;
+    }
+
+    return got->second == *d_value;
+  }
+
+  string toString() const override
+  {
+    return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
+  }
+
+private:
+  boost::optional<std::string> d_value;
+  std::string d_tag;
+};
+
 
 class DropAction : public DNSAction
 {
@@ -1402,3 +1437,53 @@ public:
 private:
   std::string d_reason;
 };
+
+class TagAction : public DNSAction
+{
+public:
+  TagAction(const std::string tag, std::string value): d_tag(tag), d_value(value)
+  {
+  }
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+  {
+    if (dq->qTag == nullptr) {
+      dq->qTag = std::make_shared<QTag>();
+    }
+
+    dq->qTag->add(d_tag, d_value);
+
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "set tag '" + d_tag + "' to value '" + d_value + "'";
+  }
+private:
+  std::string d_tag;
+  std::string d_value;
+};
+
+class TagResponseAction : public DNSResponseAction
+{
+public:
+  TagResponseAction(const std::string tag, std::string value): d_tag(tag), d_value(value)
+  {
+  }
+  DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
+  {
+    if (dr->qTag == nullptr) {
+      dr->qTag = std::make_shared<QTag>();
+    }
+
+    dr->qTag->add(d_tag, d_value);
+
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "set tag '" + d_tag + "' to value '" + d_value + "'";
+  }
+private:
+  std::string d_tag;
+  std::string d_value;
+};
diff --git a/regression-tests.dnsdist/test_Tags.py b/regression-tests.dnsdist/test_Tags.py
new file mode 100644 (file)
index 0000000..6da32b9
--- /dev/null
@@ -0,0 +1,216 @@
+#!/usr/bin/env python
+import unittest
+import dns
+import clientsubnetoption
+from dnsdisttests import DNSDistTest
+
+class TestBasics(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    function lol(dq)
+      return DNSAction.None, ""
+    end
+    addAction(AllRule(), LuaAction(lol))
+
+    addAction("tag-me-dns-1.tags.tests.powerdns.com.", TagAction("dns", "value1"))
+    addAction("tag-me-dns-2.tags.tests.powerdns.com.", TagAction("dns", "value2"))
+    addAction("tag-me-response-1.tags.tests.powerdns.com.", TagAction("response", "value1"))
+    addAction("tag-me-response-2.tags.tests.powerdns.com.", TagAction("response", "value2"))
+
+    addAction(TagRule("not-dns"), SpoofAction("1.2.3.4"))
+    addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50"))
+    addAction(TagRule("dns"), SpoofAction("1.2.3.100"))
+
+    function responseHandlerSetTC(dr)
+      dr.dh:setTC(true)
+      return DNSResponseAction.HeaderModify, ""
+    end
+
+    function responseHandlerUnsetQR(dr)
+      dr.dh:setQR(false)
+      return DNSResponseAction.HeaderModify, ""
+    end
+
+    addResponseAction(TagRule("not-dns"), DropResponseAction())
+    addResponseAction(TagRule("response", "value1"), LuaResponseAction(responseHandlerSetTC))
+    addResponseAction(TagRule("response", "no-match-value"), DropResponseAction())
+
+    addResponseAction("tag-me-response-3.tags.tests.powerdns.com.", TagResponseAction("response-tag", "value"))
+    addResponseAction(TagRule("response-tag"), LuaResponseAction(responseHandlerUnsetQR))
+    """
+
+    def testQuestionNoTag(self):
+        """
+        Tag: No match
+        """
+        name = 'no-match.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+    def testQuestionMatchTagAndValue(self):
+        """
+        Tag: Name and value match
+        """
+        name = 'tag-me-dns-1.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.50')
+        expectedResponse.answer.append(rrset)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testQuestionMatchTagOnly(self):
+        """
+        Tag: Name matches
+        """
+        name = 'tag-me-dns-2.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.100')
+        expectedResponse.answer.append(rrset)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testResponseNoMatch(self):
+        """
+        Tag: Tag set on query does not match anything
+        """
+        name = 'tag-me-response-2.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+    def testResponseMatchTagAndValue(self):
+        """
+        Tag: Tag and value set on query matches on response
+        """
+        name = 'tag-me-response-1.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.100')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+        # we will set TC if the tag matches
+        expectedResponse.flags |= dns.flags.TC
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        print(expectedResponse)
+        print(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testResponseMatchResponseTagMatches(self):
+        """
+        Tag: Tag set on response matches
+        """
+        name = 'tag-me-response-3.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.100')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+        # we will set QR=0 if the tag matches
+        expectedResponse.flags &= ~dns.flags.QR
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)