From: Remi Gacogne Date: Fri, 1 Dec 2017 18:21:20 +0000 (+0100) Subject: dnsdist: Add tag-based routing X-Git-Tag: dnsdist-1.3.0~190^2~2 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=a76b0d63ce7e2fb51925e73d0c082197b3d885e9;p=pdns dnsdist: Add tag-based routing --- diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index d70a64674..5f7fa35c3 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -423,6 +423,9 @@ const std::vector g_consoleKeywords{ { "SNMPTrapAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the query description"}, { "SNMPTrapResponseAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the response description"}, { "SpoofAction", true, "{ip, ...} ", "forge a response with the specified IPv4 (for an A query) or IPv6 (for an AAAA). If you specify multiple addresses, all that match the query type (A, AAAA or ANY) will get spoofed in" }, + { "TagAction", true, "name, value", "set the tag named 'name' to the given value" }, + { "TagResponseAction", true, "name, value", "set the tag named 'name' to the given value" }, + { "TagRule", true, "name [, value]", "matches if the tag named 'name' is present, with the given 'value' matching if any" }, { "TCAction", true, "", "create answer to query with TC and RD bits set, to move to TCP" }, { "testCrypto", true, "", "test of the crypto all works" }, { "TimedIPSetRule", true, "", "Create a rule which matches a set of IP addresses which expire"}, diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index 7e043ef39..23908b0f0 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -1481,4 +1481,16 @@ void moreLua(bool client) g_outputBuffer="recvmmsg support is not available!\n"; #endif }); + + g_lua.writeFunction("TagAction", [](std::string tag, std::string value) { + return std::shared_ptr(new TagAction(tag, value)); + }); + + g_lua.writeFunction("TagResponseAction", [](std::string tag, std::string value) { + return std::shared_ptr(new TagResponseAction(tag, value)); + }); + + g_lua.writeFunction("TagRule", [](std::string tag, boost::optional value) { + return std::shared_ptr(new TagRule(tag, value)); + }); } diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index e74ac4a39..e8f8857e0 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -390,6 +390,8 @@ void* tcpClientThread(int pipefd) #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; #endif + dr.qTag = dq.qTag; + if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) { goto drop; } @@ -542,6 +544,8 @@ void* tcpClientThread(int pipefd) #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; #endif + dr.qTag = dq.qTag; + if (!processResponse(localRespRulactions, dr, &delayMsec)) { break; } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 4e646dca2..ef859d6d9 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -448,6 +448,8 @@ try { #ifdef HAVE_PROTOBUF dr.uniqueId = ids->uniqueId; #endif + dr.qTag = ids->qTag; + if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) { continue; } @@ -1272,6 +1274,8 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; #endif + dr.qTag = dq.qTag; + if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) { return; } @@ -1361,6 +1365,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->packetCache = packetCache; ids->ednsAdded = ednsAdded; ids->ecsAdded = ecsAdded; + ids->qTag = dq.qTag; /* If we couldn't harvest the real dest addr, still write down the listening addr since it will be useful diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index c9ff9bb5a..7b9b1fd16 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -65,20 +65,20 @@ public: { } - void add(const std::string strLabel, const std::string strValue) + void add(const std::string& strLabel, const std::string& strValue) { - tagData.insert( {strLabel, strValue}); + tagData.insert({strLabel, strValue}); return; } std::string getMatch(const std::string& strLabel) const { - std::unordered_map::const_iterator got =tagData.find (strLabel); - if(got == tagData.end()) { + const auto got = tagData.find(strLabel); + if (got == tagData.cend()) { return ""; - } else { - return got->second; } + + return got->second; } std::string getEntry(size_t iEntry) const @@ -117,7 +117,7 @@ public: return strRet; } - std::unordered_maptagData; + std::unordered_map tagData; private: static constexpr char const *strSep = "\t"; @@ -410,6 +410,7 @@ struct IDState boost::optional uniqueId; #endif std::shared_ptr packetCache{nullptr}; + std::shared_ptr qTag{nullptr}; const ClientState* cs{nullptr}; uint32_t cacheKey; // 8 std::atomic age; // 4 diff --git a/pdns/dnsdistdist/docs/rules-actions.rst b/pdns/dnsdistdist/docs/rules-actions.rst index 583916de3..64f7b7fc8 100644 --- a/pdns/dnsdistdist/docs/rules-actions.rst +++ b/pdns/dnsdistdist/docs/rules-actions.rst @@ -528,6 +528,13 @@ These ``DNSRule``\ s be one of the following items: :param SuffixMatchNode smb: The SuffixMatchNode to match on :param bool quiet: Do not return the list of matched domains. Default is false. +.. function:: TagRule(name [, value]) + + Matches question or answer with a tag named ``name`` set. If ``value`` is specified, the existing tag value should match too. + + :param bool name: The name of the tag that has to be set + :param bool value: If set, the value the tag has to be set to. Default is unset + .. function:: TCPRule([tcp]) Matches question received over TCP if ``tcp`` is true, over UDP otherwise. @@ -742,6 +749,20 @@ The following actions exist. :param string cname: The name to respond with +.. function:: TagAction(name, value) + + Associate a tag named ``name`` with a value of ``value`` to this query, that will be passed on to the response. + + :param string name: The name of the tag to set + :param string cname: The value of the tag + +.. function:: TagResponseAction(name, value) + + Associate a tag named ``name`` with a value of ``value`` to this response. + + :param string name: The name of the tag to set + :param string cname: The value of the tag + .. function:: TCAction() Create answer to query with TC and RD bits set, to force the client to TCP. diff --git a/pdns/dnsrulactions.hh b/pdns/dnsrulactions.hh index 8f824efa8..254103e9c 100644 --- a/pdns/dnsrulactions.hh +++ b/pdns/dnsrulactions.hh @@ -768,9 +768,44 @@ public: } bool matches(const DNSQuestion* dq) const override; string toString() const override; +private: double d_proba; }; +class TagRule : public DNSRule +{ +public: + TagRule(std::string tag, boost::optional value) : d_value(value), d_tag(tag) + { + } + bool matches(const DNSQuestion* dq) const override + { + if (dq->qTag == nullptr) { + return false; + } + + const auto got = dq->qTag->tagData.find(d_tag); + if (got == dq->qTag->tagData.cend()) { + return false; + } + + if (!d_value) { + return true; + } + + return got->second == *d_value; + } + + string toString() const override + { + return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : ""); + } + +private: + boost::optional d_value; + std::string d_tag; +}; + class DropAction : public DNSAction { @@ -1402,3 +1437,53 @@ public: private: std::string d_reason; }; + +class TagAction : public DNSAction +{ +public: + TagAction(const std::string tag, std::string value): d_tag(tag), d_value(value) + { + } + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override + { + if (dq->qTag == nullptr) { + dq->qTag = std::make_shared(); + } + + dq->qTag->add(d_tag, d_value); + + return Action::None; + } + string toString() const override + { + return "set tag '" + d_tag + "' to value '" + d_value + "'"; + } +private: + std::string d_tag; + std::string d_value; +}; + +class TagResponseAction : public DNSResponseAction +{ +public: + TagResponseAction(const std::string tag, std::string value): d_tag(tag), d_value(value) + { + } + DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override + { + if (dr->qTag == nullptr) { + dr->qTag = std::make_shared(); + } + + dr->qTag->add(d_tag, d_value); + + return Action::None; + } + string toString() const override + { + return "set tag '" + d_tag + "' to value '" + d_value + "'"; + } +private: + std::string d_tag; + std::string d_value; +}; diff --git a/regression-tests.dnsdist/test_Tags.py b/regression-tests.dnsdist/test_Tags.py new file mode 100644 index 000000000..6da32b9a0 --- /dev/null +++ b/regression-tests.dnsdist/test_Tags.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +import unittest +import dns +import clientsubnetoption +from dnsdisttests import DNSDistTest + +class TestBasics(DNSDistTest): + + _config_template = """ + newServer{address="127.0.0.1:%s"} + + function lol(dq) + return DNSAction.None, "" + end + addAction(AllRule(), LuaAction(lol)) + + addAction("tag-me-dns-1.tags.tests.powerdns.com.", TagAction("dns", "value1")) + addAction("tag-me-dns-2.tags.tests.powerdns.com.", TagAction("dns", "value2")) + addAction("tag-me-response-1.tags.tests.powerdns.com.", TagAction("response", "value1")) + addAction("tag-me-response-2.tags.tests.powerdns.com.", TagAction("response", "value2")) + + addAction(TagRule("not-dns"), SpoofAction("1.2.3.4")) + addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50")) + addAction(TagRule("dns"), SpoofAction("1.2.3.100")) + + function responseHandlerSetTC(dr) + dr.dh:setTC(true) + return DNSResponseAction.HeaderModify, "" + end + + function responseHandlerUnsetQR(dr) + dr.dh:setQR(false) + return DNSResponseAction.HeaderModify, "" + end + + addResponseAction(TagRule("not-dns"), DropResponseAction()) + addResponseAction(TagRule("response", "value1"), LuaResponseAction(responseHandlerSetTC)) + addResponseAction(TagRule("response", "no-match-value"), DropResponseAction()) + + addResponseAction("tag-me-response-3.tags.tests.powerdns.com.", TagResponseAction("response-tag", "value")) + addResponseAction(TagRule("response-tag"), LuaResponseAction(responseHandlerUnsetQR)) + """ + + def testQuestionNoTag(self): + """ + Tag: No match + """ + name = 'no-match.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + def testQuestionMatchTagAndValue(self): + """ + Tag: Name and value match + """ + name = 'tag-me-dns-1.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.50') + expectedResponse.answer.append(rrset) + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEquals(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEquals(expectedResponse, receivedResponse) + + def testQuestionMatchTagOnly(self): + """ + Tag: Name matches + """ + name = 'tag-me-dns-2.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.100') + expectedResponse.answer.append(rrset) + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEquals(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEquals(expectedResponse, receivedResponse) + + def testResponseNoMatch(self): + """ + Tag: Tag set on query does not match anything + """ + name = 'tag-me-response-2.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + def testResponseMatchTagAndValue(self): + """ + Tag: Tag and value set on query matches on response + """ + name = 'tag-me-response-1.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.100') + response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + # we will set TC if the tag matches + expectedResponse.flags |= dns.flags.TC + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + print(expectedResponse) + print(receivedResponse) + self.assertEquals(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(expectedResponse, receivedResponse) + + def testResponseMatchResponseTagMatches(self): + """ + Tag: Tag set on response matches + """ + name = 'tag-me-response-3.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.100') + response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + # we will set QR=0 if the tag matches + expectedResponse.flags &= ~dns.flags.QR + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(expectedResponse, receivedResponse)