From 7caebbb287b58d848bc991ac4841fc98b789a123 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Wed, 28 Dec 2016 16:21:43 +0100 Subject: [PATCH] dnsdist: Add `RDRule()` to match queries with the `RD` flag set --- pdns/README-dnsdist.md | 4 +++ pdns/dnsdist-lua2.cc | 7 +++- pdns/dnsrulactions.hh | 16 +++++++++ regression-tests.dnsdist/test_Advanced.py | 41 +++++++++++++++++++++++ 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index 3d62cafb4..3d6e10d2d 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -347,6 +347,7 @@ Rules have selectors and actions. Current selectors are: * RE2Rule on query name (optional) * Response code * Packet requests DNSSEC processing + * Packet requests recursion * Query received over UDP or TCP * Opcode (OpcodeRule) * Number of entries in a given section (RecordsCountRule) @@ -405,6 +406,7 @@ A DNS rule can be: * an AllRule * an AndRule + * a DNSSECRule * a MaxQPSIPRule * a MaxQPSRule * a NetmaskGroupRule @@ -416,6 +418,7 @@ A DNS rule can be: * a QNameWireLengthRule * a QTypeRule * a RCodeRule + * a RDRule * a RegexRule * a RE2Rule * a RecordsCountRule @@ -1327,6 +1330,7 @@ instantiate a server with additional parameters * `QNameWireLengthRule(min, max)`: matches if the qname's length on the wire is less than `min` or more than `max` bytes * `QTypeRule(qtype)`: matches queries with the specified qtype * `RCodeRule(rcode)`: matches queries or responses the specified rcode + * `RDRule()`: matches queries with the `RD` flag set * `RegexRule(regex)`: matches the query name against the supplied regex * `RecordsCountRule(section, minCount, maxCount)`: matches if there is at least `minCount` and at most `maxCount` records in the `section` section * `RecordsTypeCountRule(section, type, minCount, maxCount)`: matches if there is at least `minCount` and at most `maxCount` records of type `type` in the `section` section diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index 9c6626606..419e6f903 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -1154,5 +1154,10 @@ void moreLua(bool client) return; } g_rings.setCapacity(capacity); - }); + }); + + g_lua.writeFunction("RDRule", []() { + return std::shared_ptr(new RDRule()); + }); + } diff --git a/pdns/dnsrulactions.hh b/pdns/dnsrulactions.hh index 0244cd6ea..9af98903c 100644 --- a/pdns/dnsrulactions.hh +++ b/pdns/dnsrulactions.hh @@ -579,6 +579,22 @@ private: int d_rcode; }; +class RDRule : public DNSRule +{ +public: + RDRule() + { + } + bool matches(const DNSQuestion* dq) const override + { + return dq->dh->rd == 1; + } + string toString() const override + { + return "rd==1"; + } +}; + class DropAction : public DNSAction { diff --git a/regression-tests.dnsdist/test_Advanced.py b/regression-tests.dnsdist/test_Advanced.py index eb35ed0a3..a32e96df6 100644 --- a/regression-tests.dnsdist/test_Advanced.py +++ b/regression-tests.dnsdist/test_Advanced.py @@ -1330,3 +1330,44 @@ advanced.tests.powerdns.com. tests.powerdns.com. powerdns.com. com.""") + +class TestAdvancedRD(DNSDistTest): + + _config_template = """ + addAction(RDRule(), RCodeAction(dnsdist.REFUSED)) + newServer{address="127.0.0.1:%s"} + """ + + def testAdvancedRDRefused(self): + """ + Advanced: RD query is refused + """ + name = 'rd.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + expectedResponse.set_rcode(dns.rcode.REFUSED) + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEquals(receivedResponse, expectedResponse) + + (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertEquals(receivedResponse, expectedResponse) + + def testAdvancedNoRDAllowed(self): + """ + Advanced: No-RD query is allowed + """ + name = 'no-rd.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + receivedQuery.id = query.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + receivedQuery.id = query.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) -- 2.40.0