]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add regression tests for protobuf alteration via Lua
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 8 Sep 2016 12:05:13 +0000 (14:05 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 8 Sep 2016 12:05:13 +0000 (14:05 +0200)
regression-tests.dnsdist/test_Protobuf.py

index c5722ec73bc482faeca1236b68d76bc6f5425cf0..d4f017dc6bbbe97347ff2a0d81dec30087bccaf6 100644 (file)
@@ -17,10 +17,25 @@ class TestProtobuf(DNSDistTest):
     _protobufCounter = 0
     _config_params = ['_testServerPort', '_protobufServerPort']
     _config_template = """
+    luasmn = newSuffixMatchNode()
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+    function alterProtobuf(dq, protobuf)
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())
+        if requestor:isIPv4() then
+          requestor:truncate(24)
+        else
+          requestor:truncate(56)
+        end
+        protobuf:setRequestor(requestor)
+      end
+    end
+
     newServer{address="127.0.0.1:%s", useClientSubnet=true}
     rl = newRemoteLogger('127.0.0.1:%s')
-    addAction(AllRule(), RemoteLogAction(rl))
-    addResponseAction(AllRule(), RemoteLogResponseAction(rl))
+    addAction(AllRule(), RemoteLogAction(rl, alterProtobuf))
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf))
     """
 
     @classmethod
@@ -72,14 +87,14 @@ class TestProtobuf(DNSDistTest):
         msg.ParseFromString(data)
         return msg
 
-    def checkProtobufBase(self, msg, protocol, query):
+    def checkProtobufBase(self, msg, protocol, query, initiator):
         self.assertTrue(msg)
         self.assertTrue(msg.HasField('timeSec'))
         self.assertTrue(msg.HasField('socketFamily'))
         self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
         self.assertTrue(msg.HasField('from'))
         fromvalue = getattr(msg, 'from')
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1')
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
         self.assertTrue(msg.HasField('socketProtocol'))
         self.assertEquals(msg.socketProtocol, protocol)
         self.assertTrue(msg.HasField('messageId'))
@@ -93,9 +108,9 @@ class TestProtobuf(DNSDistTest):
         # self.assertEquals(len(msg.originalRequestorSubnet), 4)
         # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
 
-    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname):
+    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
         self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
-        self.checkProtobufBase(msg, protocol, query)
+        self.checkProtobufBase(msg, protocol, query, initiator)
         # dnsdist doesn't fill the responder field for responses
         # because it doesn't keep the information around.
         self.assertTrue(msg.HasField('to'))
@@ -108,12 +123,23 @@ class TestProtobuf(DNSDistTest):
         self.assertTrue(msg.question.HasField('qName'))
         self.assertEquals(msg.question.qName, qname)
 
-    def checkProtobufResponse(self, msg, protocol, response):
+    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
         self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
-        self.checkProtobufBase(msg, protocol, response)
+        self.checkProtobufBase(msg, protocol, response, initiator)
         self.assertTrue(msg.HasField('response'))
         self.assertTrue(msg.response.HasField('queryTimeSec'))
 
+    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
+        self.assertTrue(record.HasField('class'))
+        self.assertEquals(getattr(record, 'class'), rclass)
+        self.assertTrue(record.HasField('type'))
+        self.assertEquals(record.type, rtype)
+        self.assertTrue(record.HasField('name'))
+        self.assertEquals(record.name, rname)
+        self.assertTrue(record.HasField('ttl'))
+        self.assertEquals(record.ttl, rttl)
+        self.assertTrue(record.HasField('rdata'))
+
     def testProtobuf(self):
         """
         Protobuf: Send data to a protobuf server
@@ -147,15 +173,7 @@ class TestProtobuf(DNSDistTest):
         self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
         self.assertEquals(len(msg.response.rrs), 1)
         for rr in msg.response.rrs:
-            self.assertTrue(rr.HasField('class'))
-            self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
-            self.assertTrue(rr.HasField('type'))
-            self.assertEquals(rr.type, dns.rdatatype.A)
-            self.assertTrue(rr.HasField('name'))
-            self.assertEquals(rr.name, name)
-            self.assertTrue(rr.HasField('ttl'))
-            self.assertEquals(rr.ttl, 3600)
-            self.assertTrue(rr.HasField('rdata'))
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
             self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
@@ -177,13 +195,63 @@ class TestProtobuf(DNSDistTest):
         self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
         self.assertEquals(len(msg.response.rrs), 1)
         for rr in msg.response.rrs:
-            self.assertTrue(rr.HasField('class'))
-            self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
-            self.assertTrue(rr.HasField('type'))
-            self.assertEquals(rr.type, dns.rdatatype.A)
-            self.assertTrue(rr.HasField('name'))
-            self.assertEquals(rr.name, name)
-            self.assertTrue(rr.HasField('ttl'))
-            self.assertEquals(rr.ttl, 3600)
-            self.assertTrue(rr.HasField('rdata'))
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+    def testLuaProtobuf(self):
+        """
+        Protobuf: Check that the Lua callback rewrote the initiator
+        """
+        name = 'lua.protobuf.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
             self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')