From: Seth Ornstein Date: Fri, 23 Jun 2017 04:12:27 +0000 (-0400) Subject: completed test for mods to dnsdist in pdns/regression-tests.dnsdist/test_ProtobufTag.py X-Git-Tag: dnsdist-1.2.0~39^2~7 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=8a87971069babcff64305ef9e79b8e0615d13ff2;p=pdns completed test for mods to dnsdist in pdns/regression-tests.dnsdist/test_ProtobufTag.py script to execute it in pdns/zzz-gca-example/test-protobuf-tag.sh --- diff --git a/regression-tests.dnsdist/test_ProtobufTag.py b/regression-tests.dnsdist/test_ProtobufTag.py new file mode 100644 index 000000000..976cecf37 --- /dev/null +++ b/regression-tests.dnsdist/test_ProtobufTag.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python +import Queue +import threading +import socket +import struct +import sys +import time +from dnsdisttests import DNSDistTest + +import dns +import dnsmessage_pb2 + + +class TestProtobuf(DNSDistTest): + _protobufServerPort = 4242 + _protobufQueue = Queue.Queue() + _protobufCounter = 0 + _config_params = ['_testServerPort', '_protobufServerPort'] + _config_template = """ + luasmn = newSuffixMatchNode() + luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) + + function alterProtobufResponse(dq, protobuf) + if luasmn:check(dq.qname) then + requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() + if requestor:isIPv4() then + requestor:truncate(24) + else + requestor:truncate(56) + end + protobuf:setRequestor(requestor) + + local tableTags = {} + tableTags["TestLabel2"] = "TestData2" + tableTags["TestLabel1"] = "TestData1" + protobuf:setTagArray(tableTags) -- setTagArray + + protobuf:setTag('TestLabel3', 'TestData3') -- setTag + + protobuf:setTag("Response", "456") -- setTag + else + local tableTags = {} -- called by testProtobuf() + tableTags["TestLabel2"] = "TestData2" + tableTags["TestLabel1"] = "TestData1" + protobuf:setTagArray(tableTags) -- setTagArray + + protobuf:setTag('TestLabel3', 'TestData3') -- setTag + + protobuf:setTag("Response", "456") -- setTag + end + end + + function alterProtobufQuery(dq, protobuf) + if luasmn:check(dq.qname) then + requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() + if requestor:isIPv4() then + requestor:truncate(24) + else + requestor:truncate(56) + end + protobuf:setRequestor(requestor) + + local tableTags = {} -- declare table + tableTags = dq:getTagArray() -- get table from DNSQuery + + protobuf:setTagArray(tableTags) -- store table in protobuf + protobuf:setTag("Query", "123") -- add another tag entry in protobuf + + protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN + + local strReqName = dq.qname:toString() -- get request dns name + + protobuf:setProtobufResponseType(strReqName) -- set protobuf to look like a response and not a query + + else + local tableTags = {} -- called by testProtobuf() + tableTags["TestLabel2"] = "TestData2" + tableTags["TestLabel1"] = "TestData1" + protobuf:setTagArray(tableTags) -- setTagArray + protobuf:setTag('TestLabel3', 'TestData3') -- setTag + protobuf:setTag("Query", "123") -- setTag + end + end + + function alterLuaFirst(dq) -- called when dnsdist receives new request + + local tt = {} + tt["TestLabel2"] = "TestData2" + tt["TestLabel1"] = "TestData1" + + dq:setTagArray(tt) -- setTagArray + + dq:setTag('TestLabel3', 'TestData3') -- setTag + + return DNSAction.None, "" -- continue to the next rule + end + + + newServer{address="127.0.0.1:%s", useClientSubnet=true} + rl = newRemoteLogger('127.0.0.1:%s') + + addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first + + addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup + + addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup + + """ + + @classmethod + def ProtobufListener(cls, port): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + try: + sock.bind(("127.0.0.1", port)) + except socket.error as e: + print("Error binding in the protbuf listener: %s" % str(e)) + sys.exit(1) + + sock.listen(100) + while True: + (conn, _) = sock.accept() + data = None + while True: + data = conn.recv(2) + if not data: + break + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) + if not data: + break + + cls._protobufQueue.put(data, True, timeout=2.0) + + conn.close() + sock.close() + + @classmethod + def startResponders(cls): + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort]) + cls._UDPResponder.setDaemon(True) + cls._UDPResponder.start() + + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort]) + cls._TCPResponder.setDaemon(True) + cls._TCPResponder.start() + + cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort]) + cls._protobufListener.setDaemon(True) + cls._protobufListener.start() + + def getFirstProtobufMessage(self): + self.assertFalse(self._protobufQueue.empty()) + data = self._protobufQueue.get(False) + self.assertTrue(data) + msg = dnsmessage_pb2.PBDNSMessage() + msg.ParseFromString(data) + return msg + + def checkProtobufBase(self, msg, protocol, query, initiator): + self.assertTrue(msg) + self.assertTrue(msg.HasField('timeSec')) + self.assertTrue(msg.HasField('socketFamily')) + self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET) + self.assertTrue(msg.HasField('from')) + fromvalue = getattr(msg, 'from') + self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator) + self.assertTrue(msg.HasField('socketProtocol')) + self.assertEquals(msg.socketProtocol, protocol) + self.assertTrue(msg.HasField('messageId')) + self.assertTrue(msg.HasField('id')) + self.assertEquals(msg.id, query.id) + self.assertTrue(msg.HasField('inBytes')) + self.assertEquals(msg.inBytes, len(query.to_wire())) + # dnsdist doesn't set the existing EDNS Subnet for now, + # although it might be set from Lua + # self.assertTrue(msg.HasField('originalRequestorSubnet')) + # self.assertEquals(len(msg.originalRequestorSubnet), 4) + # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1') + + + def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'): + + if initiator == '127.0.0.1': + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) # testProtobuf() + else: + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) # testLuaProtobuf() + + self.checkProtobufBase(msg, protocol, query, initiator) + # dnsdist doesn't fill the responder field for responses + # because it doesn't keep the information around. + self.assertTrue(msg.HasField('to')) + self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1') + self.assertTrue(msg.HasField('question')) + self.assertTrue(msg.question.HasField('qClass')) + self.assertEquals(msg.question.qClass, qclass) + self.assertTrue(msg.question.HasField('qType')) + self.assertEquals(msg.question.qClass, qtype) + self.assertTrue(msg.question.HasField('qName')) + self.assertEquals(msg.question.qName, qname) + + + + testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"] + listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list + self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query") # exclusive or of lists should be empty + + def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'): + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) + self.checkProtobufBase(msg, protocol, response, initiator) + self.assertTrue(msg.HasField('response')) + self.assertTrue(msg.response.HasField('queryTimeSec')) + + testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"] + listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list + self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response") # exclusive or of lists should be empty + + def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl): + self.assertTrue(record.HasField('class')) + self.assertEquals(getattr(record, 'class'), rclass) + self.assertTrue(record.HasField('type')) + self.assertEquals(record.type, rtype) + self.assertTrue(record.HasField('name')) + self.assertEquals(record.name, rname) + self.assertTrue(record.HasField('ttl')) + self.assertEquals(record.ttl, rttl) + self.assertTrue(record.HasField('rdata')) + + def testProtobuf(self): + """ + Protobuf: Send data to a protobuf server + """ + name = 'query.protobuf.tests.powerdns.com.' + + + target = 'target.protobuf.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.CNAME, + target) + response.answer.append(rrset) + + rrset = dns.rrset.from_text(target, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the UDP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + + # check the protobuf message corresponding to the UDP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) # check UDP response + self.assertEquals(len(msg.response.rrs), 2) + rr = msg.response.rrs[0] + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) + self.assertEquals(rr.rdata, target) + rr = msg.response.rrs[1] + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the TCP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + + # check the protobuf message corresponding to the TCP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) # check TCP response + self.assertEquals(len(msg.response.rrs), 2) + rr = msg.response.rrs[0] + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) + self.assertEquals(rr.rdata, target) + rr = msg.response.rrs[1] + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + def testLuaProtobuf(self): + + """ + Protobuf: Check that the Lua callback rewrote the initiator + """ + + name = 'lua.protobuf.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the UDP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') + + # check the protobuf message corresponding to the UDP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') # check UDP Response + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the TCP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') + + # check the protobuf message corresponding to the TCP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') # check TCP response + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') diff --git a/zzz-gca-example/dnsdist.conf b/zzz-gca-example/dnsdist.conf index a79521a6a..a547b2dbd 100644 --- a/zzz-gca-example/dnsdist.conf +++ b/zzz-gca-example/dnsdist.conf @@ -178,7 +178,6 @@ function luaLogForward(dr, pbMsg) - local tableTags = {} -- create a table tableTags["Trans"] = "FWD" -- add transaction type to table @@ -200,7 +199,7 @@ end function luaLogCache(dr, pbMsg) -- this is the lua code that executes after a cache hit - + local tableTags = {} -- create a table tableTags["Trans"] = "CACHE" -- add transaction type to table diff --git a/zzz-gca-example/test-protobuf-tag.sh b/zzz-gca-example/test-protobuf-tag.sh new file mode 100755 index 000000000..9a0103c88 --- /dev/null +++ b/zzz-gca-example/test-protobuf-tag.sh @@ -0,0 +1,14 @@ +cd ../regression-tests.dnsdist +DNSDISTBIN=../pdns/dnsdistdist/dnsdist ./runtests test_ProtobufTag.py + + +echo "-----------------------------------------------------------" +echo "-----------------------------------------------------------" +echo "-----------------------------------------------------------" +echo "-----------------------------------------------------------" +echo "-----------------------------------------------------------" +echo "-----------------------------------------------------------" + +cat nosetests.xml + +