From: Remi Gacogne Date: Mon, 15 Jul 2019 11:02:41 +0000 (+0200) Subject: dnsdist: Add a regression test for invalid DNS answer from the backend X-Git-Tag: dnsdist-1.4.0-rc1~47^2 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=a620f197a9f3aad9aa9dd6c09f65a90319138189;p=pdns dnsdist: Add a regression test for invalid DNS answer from the backend --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 98ebd4d2a..792ceefcc 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -176,10 +176,11 @@ class DNSDistTest(unittest.TestCase): return response @classmethod - def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False): + def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. + # callback is invoked for every -even healthcheck ones- query and should return a raw response ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -197,20 +198,24 @@ class DNSDistTest(unittest.TestCase): request = dns.message.from_wire(data, ignore_trailing=True) forceRcode = trailingDataResponse - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if not response: - continue + if callback: + wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire() sock.settimeout(2.0) - sock.sendto(response.to_wire(), addr) + sock.sendto(wire, addr) sock.settimeout(None) sock.close() @classmethod - def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False): + def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. + # callback is invoked for every -even healthcheck ones- query and should return a raw response ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -243,12 +248,17 @@ class DNSDistTest(unittest.TestCase): request = dns.message.from_wire(data, ignore_trailing=True) forceRcode = trailingDataResponse - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if not response: + if callback: + wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire(max_size=65535) + + if not wire: conn.close() continue - wire = response.to_wire(max_size=65535) conn.send(struct.pack("!H", len(wire))) conn.send(wire) diff --git a/regression-tests.dnsdist/test_BrokenAnswer.py b/regression-tests.dnsdist/test_BrokenAnswer.py new file mode 100644 index 000000000..a706dfa7e --- /dev/null +++ b/regression-tests.dnsdist/test_BrokenAnswer.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +import threading +import clientsubnetoption +import dns +from dnsdisttests import DNSDistTest + +def responseCallback(request): + if len(request.question) != 1: + print("Skipping query with question count %d" % (len(request.question))) + return None + healthCheck = str(request.question[0].name).endswith('a.root-servers.net.') + if healthCheck: + response = dns.message.make_response(request) + return response.to_wire() + # now we create a broken response + response = dns.message.make_response(request) + ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32) + response.use_edns(edns=True, payload=4096, options=[ecso]) + rrset = dns.rrset.from_text(request.question[0].name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + raw = response.to_wire() + # first label length of this rrset is at 12 (dnsheader) + length(qname) + 2 (leading label length + trailing 0) + 2 (qtype) + 2 (qclass) + offset = 12 + len(str(request.question[0].name)) + 2 + 2 + 2 + altered = raw[:offset] + chr(255).encode() + raw[offset+1:] + return altered + +class TestBrokenAnswerECS(DNSDistTest): + + # this test suite uses a different responder port + # because, contrary to the other ones, its + # responders send raw, broken data + _testServerPort = 5400 + _config_template = """ + setECSSourcePrefixV4(32) + newServer{address="127.0.0.1:%s", useClientSubnet=true} + """ + @classmethod + def startResponders(cls): + print("Launching responders..") + + # Returns broken data for non-healthcheck queries + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, responseCallback]) + cls._UDPResponder.setDaemon(True) + cls._UDPResponder.start() + + # Returns broken data for non-healthcheck queries + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, responseCallback]) + cls._TCPResponder.setDaemon(True) + cls._TCPResponder.start() + + def testUDPWithInvalidAnswer(self): + """ + Broken Answer: Invalid UDP answer with ECS + """ + name = 'invalid-ecs-udp.broken-answer.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertFalse(receivedQuery) + self.assertFalse(receivedResponse) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertFalse(receivedQuery) + self.assertFalse(receivedResponse) + + def testTCPWithInvalidAnswer(self): + """ + Broken Answer: Invalid TCP answer with ECS + """ + name = 'invalid-ecs-tcp.broken-answer.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertFalse(receivedQuery) + self.assertFalse(receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertFalse(receivedQuery) + self.assertFalse(receivedResponse)