From: Seth Ornstein Date: Thu, 6 Jul 2017 15:06:05 +0000 (-0400) Subject: addressed fixes requested by Remi July 3rd X-Git-Tag: dnsdist-1.2.0~39^2~5 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=741ebe08b36032e6254dea516a24451b978c81f2;p=pdns addressed fixes requested by Remi July 3rd --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index e1ccacb2d..ee7c438fd 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1586,54 +1586,47 @@ vector> setupLua(bool client, const std::string& confi } }); - g_lua.registerFunction("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) { + g_lua.registerFunction("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) { - if(dq.qTag == NULL) - { - dq.qTag = std::shared_ptr(new QTag); - } + if(dq.qTag == nullptr) { + dq.qTag = std::make_shared(); + } dq.qTag->add(strLabel, strValue); }); - g_lua.registerFunction>)>("setTagArray", [](DNSQuestion& dq, const vector>&tags) { + g_lua.registerFunction>)>("setTagArray", [](DNSQuestion& dq, const vector>&tags) { - if(dq.qTag == NULL) - { - dq.qTag = std::shared_ptr(new QTag); - } + if(dq.qTag == nullptr) { + dq.qTag = std::make_shared(); + } - for (const auto& tag : tags) - { - dq.qTag->add(tag.first, tag.second); - } + for (const auto& tag : tags) { + dq.qTag->add(tag.first, tag.second); + } }); - g_lua.registerFunction("getTagMatch", [](const DNSQuestion& dq, const std::string& strLabel) { + g_lua.registerFunction("getTag", [](const DNSQuestion& dq, const std::string& strLabel) { std::string strValue; - if(dq.qTag != NULL) - { - strValue = dq.qTag->getMatch(strLabel); - } + if(dq.qTag != nullptr) { + strValue = dq.qTag->getMatch(strLabel); + } return strValue; }); - g_lua.registerFunction(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) { + g_lua.registerFunction(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) { - if(dq.qTag != NULL) - { - return dq.qTag->tagData; - } - else - { - std::unordered_map XX; - return(XX); - } + if(dq.qTag != nullptr) { + return dq.qTag->tagData; + } else { + std::unordered_map XX; + return XX; + } }); diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index 6d8e9af89..f31046c62 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -831,47 +831,27 @@ void moreLua(bool client) #endif }); - g_lua.registerFunction("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) { - + g_lua.registerFunction("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) { message.addTag(strValue); }); - g_lua.registerFunction>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector>&tags) { - - - for (const auto& tag : tags) - { - message.addTag(tag.second); - } + g_lua.registerFunction>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector>&tags) { + for (const auto& tag : tags) { + message.addTag(tag.second); + } }); - g_lua.registerFunction sec, boost::optional uSec)>("setProtobufResponseType", + g_lua.registerFunction sec, boost::optional uSec)>("setProtobufResponseType", [](DNSDistProtoBufMessage& message, boost::optional sec, boost::optional uSec) { - message.setType(DNSProtoBufMessage::Response); - message.setQueryTime(sec?*sec:0, uSec?*uSec:0); - }); - g_lua.registerFunction> )>("setProtobufResponseRR", [](DNSDistProtoBufMessage& message, - const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const vector>& blobData) { - - size_t blobSize = blobData.size(); - - unique_ptr ptrBlob (new uint8_t(blobSize)); - - int jj=0; - for (const auto& blob : blobData) - { - ptrBlob[jj++] = blob.second; - } - - message.addRR(strQueryName, uType, uClass, uTTL, ptrBlob.get(), blobSize); - + g_lua.registerFunction("addResponseRR", [](DNSDistProtoBufMessage& message, + const std::string& strQueryName, uint16_t uType, uint uClass, uint32_t uTTL, const std::string& strBlob) { + message.addRR(strQueryName, uType, uClass, uTTL, strBlob); }); - g_lua.registerFunction("setEDNSSubnet", [](DNSDistProtoBufMessage& message, const Netmask& subnet) { message.setEDNSSubnet(subnet); }); g_lua.registerFunction("setQuestion", [](DNSDistProtoBufMessage& message, const DNSName& qname, uint16_t qtype, uint16_t qclass) { message.setQuestion(qname, qtype, qclass); }); g_lua.registerFunction("setBytes", [](DNSDistProtoBufMessage& message, size_t bytes) { message.setBytes(bytes); }); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 85832ebbe..bb427a8ba 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -68,26 +68,22 @@ QTag() { } -bool add(std::string strLabel, std::string strValue) +void add(std::string strLabel, std::string strValue) { -bool bStatus = true; - tagData.insert( {strLabel, strValue}); - return(bStatus); + tagData.insert( {strLabel, strValue}); + return; } std::string getMatch(const std::string& strLabel) const { - std::unordered_map::const_iterator got =tagData.find (strLabel); - if(got == tagData.end()) - { - return(""); - } - else - { - return(got->second); - } + std::unordered_map::const_iterator got =tagData.find (strLabel); + if(got == tagData.end()) { + return ""; + } else { + return got->second; + } } std::string getEntry(size_t iEntry) const @@ -96,25 +92,23 @@ std::string strEntry; size_t iCounter = 0; - for (const auto& itr : tagData) - { - iCounter++; - if(iCounter == iEntry) - { - strEntry = itr.first; - strEntry += strSep; - strEntry += itr.second; - break; - } - } + for (const auto& itr : tagData) { + iCounter++; + if(iCounter == iEntry) { + strEntry = itr.first; + strEntry += strSep; + strEntry += itr.second; + break; + } + } - return(strEntry); + return strEntry; } size_t count() const { - return(tagData.size()); + return tagData.size(); } std::string dumpString() const @@ -122,15 +116,13 @@ std::string dumpString() const std::string strRet; - for (const auto& itr : tagData) - { - strRet += itr.first; - strRet += strSep; - strRet += itr.second; - strRet += "\n"; - } - return(strRet); - + for (const auto& itr : tagData) { + strRet += itr.first; + strRet += strSep; + strRet += itr.second; + strRet += "\n"; + } + return strRet; } diff --git a/pdns/protobuf.cc b/pdns/protobuf.cc index 47f02cb64..66c0c8b22 100644 --- a/pdns/protobuf.cc +++ b/pdns/protobuf.cc @@ -110,7 +110,7 @@ void DNSProtoBufMessage::addTag(const std::string& strValue) #endif /* HAVE_PROTOBUF */ } -void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const uint8_t *ptrBlob, size_t uBlobLen) +void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const std::string& strBlob) { #ifdef HAVE_PROTOBUF @@ -120,12 +120,11 @@ void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint3 PBDNSMessage_DNSResponse_DNSRR* rr = response->add_rrs(); if (rr) { - string blob; rr->set_name(strName.c_str()); rr->set_type(uType); rr->set_class_(uClass); rr->set_ttl(uTTL); - rr->set_rdata(ptrBlob, uBlobLen); + rr->set_rdata((const uint8_t *) strBlob.c_str(), strBlob.size()); } #endif /* HAVE_PROTOBUF */ diff --git a/pdns/protobuf.hh b/pdns/protobuf.hh index b31a5a45b..9decfdc24 100644 --- a/pdns/protobuf.hh +++ b/pdns/protobuf.hh @@ -71,7 +71,7 @@ public: void setRequestorId(const std::string& requestorId); std::string toDebugString() const; void addTag(const std::string& strValue); - void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const uint8_t *ptrBlob, size_t uBlobLen); + void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const std::string& strBlob); #ifdef HAVE_PROTOBUF DNSProtoBufMessage(DNSProtoBufMessage::DNSProtoBufMessageType type, const boost::uuids::uuid& uuid, const ComboAddress* requestor, const ComboAddress* responder, const DNSName& domain, int qtype, uint16_t qclass, uint16_t qid, bool isTCP, size_t bytes); diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py index 6bbd4cfb4..ba3ffdf88 100644 --- a/regression-tests.dnsdist/test_Protobuf.py +++ b/regression-tests.dnsdist/test_Protobuf.py @@ -10,32 +10,110 @@ from dnsdisttests import DNSDistTest import dns import dnsmessage_pb2 -class TestProtobuf(DNSDistTest): +class TestProtobuf(DNSDistTest): _protobufServerPort = 4242 _protobufQueue = Queue.Queue() _protobufCounter = 0 _config_params = ['_testServerPort', '_protobufServerPort'] _config_template = """ - luasmn = newSuffixMatchNode() - luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) - - function alterProtobuf(dq, protobuf) - if luasmn:check(dq.qname) then - requestor = newCA(dq.remoteaddr:toString()) - if requestor:isIPv4() then - requestor:truncate(24) + luasmn = newSuffixMatchNode() + luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) + + function alterProtobufResponse(dq, protobuf) + if luasmn:check(dq.qname) then + requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() + if requestor:isIPv4() then + requestor:truncate(24) else - requestor:truncate(56) + requestor:truncate(56) end - protobuf:setRequestor(requestor) + protobuf:setRequestor(requestor) + + local tableTags = {} + table.insert(tableTags, "TestLabel1,TestData1") + table.insert(tableTags, "TestLabel2,TestData2") + + protobuf:setTagArray(tableTags) + + protobuf:setTag('TestLabel3,TestData3') + + protobuf:setTag("Response,456") + else + local tableTags = {} -- called by testProtobuf() + table.insert(tableTags, "TestLabel1,TestData1") + table.insert(tableTags, "TestLabel2,TestData2") + protobuf:setTagArray(tableTags) + + protobuf:setTag('TestLabel3,TestData3') + + protobuf:setTag("Response,456") + end + end + + function alterProtobufQuery(dq, protobuf) + if luasmn:check(dq.qname) then + requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() + if requestor:isIPv4() then + requestor:truncate(24) + else + requestor:truncate(56) + end + protobuf:setRequestor(requestor) + + local tableTags = {} + tableTags = dq:getTagArray() -- get table from DNSQuery + + local tablePB = {} + for k, v in pairs( tableTags) do + table.insert(tablePB, k .. "," .. v) + end + + protobuf:setTagArray(tablePB) -- store table in protobuf + protobuf:setTag("Query,123") -- add another tag entry in protobuf + + protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN + + local strReqName = dq.qname:toString() -- get request dns name + + protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time + + blobData = '\127' .. '\000' .. '\000' .. '\001' -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex + protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf + + else + + local tableTags = {} -- called by testProtobuf() + table.insert(tableTags, "TestLabel1,TestData1") + table.insert(tableTags, "TestLabel2,TestData2") + + protobuf:setTagArray(tableTags) + protobuf:setTag('TestLabel3,TestData3') + protobuf:setTag("Query,123") end end + function alterLuaFirst(dq) -- called when dnsdist receives new request + local tt = {} + tt["TestLabel1"] = "TestData1" + tt["TestLabel2"] = "TestData2" + + dq:setTagArray(tt) + + dq:setTag("TestLabel3","TestData3") + return DNSAction.None, "" -- continue to the next rule + end + + newServer{address="127.0.0.1:%s", useClientSubnet=true} rl = newRemoteLogger('127.0.0.1:%s') - addAction(AllRule(), RemoteLogAction(rl, alterProtobuf)) - addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf, true)) + + addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first + + addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup + + addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup + """ @classmethod @@ -71,6 +149,7 @@ class TestProtobuf(DNSDistTest): cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort]) cls._UDPResponder.setDaemon(True) cls._UDPResponder.start() + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort]) cls._TCPResponder.setDaemon(True) cls._TCPResponder.start() @@ -94,22 +173,28 @@ class TestProtobuf(DNSDistTest): self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET) self.assertTrue(msg.HasField('from')) fromvalue = getattr(msg, 'from') - self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator) + self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator) self.assertTrue(msg.HasField('socketProtocol')) self.assertEquals(msg.socketProtocol, protocol) self.assertTrue(msg.HasField('messageId')) self.assertTrue(msg.HasField('id')) - self.assertEquals(msg.id, query.id) + self.assertEquals(msg.id, query.id) self.assertTrue(msg.HasField('inBytes')) - self.assertEquals(msg.inBytes, len(query.to_wire())) + self.assertEquals(msg.inBytes, len(query.to_wire())) # dnsdist doesn't set the existing EDNS Subnet for now, # although it might be set from Lua # self.assertTrue(msg.HasField('originalRequestorSubnet')) # self.assertEquals(len(msg.originalRequestorSubnet), 4) # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1') + def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'): - self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) + + if initiator == '127.0.0.1': + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) # testProtobuf() + else: + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) # testLuaProtobuf() + self.checkProtobufBase(msg, protocol, query, initiator) # dnsdist doesn't fill the responder field for responses # because it doesn't keep the information around. @@ -123,11 +208,22 @@ class TestProtobuf(DNSDistTest): self.assertTrue(msg.question.HasField('qName')) self.assertEquals(msg.question.qName, qname) + + + testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"] + listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list + + self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query") # exclusive or of lists should be empty + def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'): - self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) self.checkProtobufBase(msg, protocol, response, initiator) - self.assertTrue(msg.HasField('response')) - self.assertTrue(msg.response.HasField('queryTimeSec')) + self.assertTrue(msg.HasField('response')) + self.assertTrue(msg.response.HasField('queryTimeSec')) + + testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"] + listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list + self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response") # exclusive or of lists should be empty def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl): self.assertTrue(record.HasField('class')) @@ -144,16 +240,20 @@ class TestProtobuf(DNSDistTest): """ Protobuf: Send data to a protobuf server """ - name = 'query.protobuf.tests.powerdns.com.' + name = 'query.protobuf.tests.powerdns.com.' + + target = 'target.protobuf.tests.powerdns.com.' query = dns.message.make_query(name, 'A', 'IN') response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, dns.rdatatype.CNAME, target) response.answer.append(rrset) + rrset = dns.rrset.from_text(target, 3600, dns.rdataclass.IN, @@ -173,12 +273,13 @@ class TestProtobuf(DNSDistTest): # check the protobuf message corresponding to the UDP query msg = self.getFirstProtobufMessage() - self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) # check the protobuf message corresponding to the UDP response msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) - self.assertEquals(len(msg.response.rrs), 2) + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) # check UDP response + self.assertEquals(len(msg.response.rrs), 2) rr = msg.response.rrs[0] self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) self.assertEquals(rr.rdata, target) @@ -202,7 +303,7 @@ class TestProtobuf(DNSDistTest): # check the protobuf message corresponding to the TCP response msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) # check TCP response self.assertEquals(len(msg.response.rrs), 2) rr = msg.response.rrs[0] self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) @@ -212,6 +313,7 @@ class TestProtobuf(DNSDistTest): self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') def testLuaProtobuf(self): + """ Protobuf: Check that the Lua callback rewrote the initiator """ @@ -225,13 +327,16 @@ class TestProtobuf(DNSDistTest): '127.0.0.1') response.answer.append(rrset) + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) + + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) receivedQuery.id = query.id self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) + # let the protobuf messages the time to get there time.sleep(1) @@ -241,7 +346,7 @@ class TestProtobuf(DNSDistTest): # check the protobuf message corresponding to the UDP response msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') # check UDP Response self.assertEquals(len(msg.response.rrs), 1) for rr in msg.response.rrs: self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) @@ -263,7 +368,7 @@ class TestProtobuf(DNSDistTest): # check the protobuf message corresponding to the TCP response msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') # check TCP response self.assertEquals(len(msg.response.rrs), 1) for rr in msg.response.rrs: self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) diff --git a/regression-tests.dnsdist/test_ProtobufTag.py b/regression-tests.dnsdist/test_ProtobufTag.py deleted file mode 100644 index bb743a632..000000000 --- a/regression-tests.dnsdist/test_ProtobufTag.py +++ /dev/null @@ -1,377 +0,0 @@ -#!/usr/bin/env python -import Queue -import threading -import socket -import struct -import sys -import time -from dnsdisttests import DNSDistTest - -import dns -import dnsmessage_pb2 - - -class TestProtobuf(DNSDistTest): - _protobufServerPort = 4242 - _protobufQueue = Queue.Queue() - _protobufCounter = 0 - _config_params = ['_testServerPort', '_protobufServerPort'] - _config_template = """ - luasmn = newSuffixMatchNode() - luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) - - function alterProtobufResponse(dq, protobuf) - if luasmn:check(dq.qname) then - requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() - if requestor:isIPv4() then - requestor:truncate(24) - else - requestor:truncate(56) - end - protobuf:setRequestor(requestor) - - local tableTags = {} - table.insert(tableTags, "TestLabel1,TestData1") - table.insert(tableTags, "TestLabel2,TestData2") - - -- tableTags["TestLabel2"] = "TestData2" --- tableTags["TestLabel1"] = "TestData1" - protobuf:setTagArray(tableTags) -- setTagArray - - protobuf:setTag('TestLabel3,TestData3') -- setTag - - protobuf:setTag("Response,456") -- setTag - else - local tableTags = {} -- called by testProtobuf() - table.insert(tableTags, "TestLabel1,TestData1") - table.insert(tableTags, "TestLabel2,TestData2") - protobuf:setTagArray(tableTags) -- setTagArray - - protobuf:setTag('TestLabel3,TestData3') -- setTag - - protobuf:setTag("Response,456") -- setTag - end - end - - function alterProtobufQuery(dq, protobuf) - if luasmn:check(dq.qname) then - requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf() - if requestor:isIPv4() then - requestor:truncate(24) - else - requestor:truncate(56) - end - protobuf:setRequestor(requestor) - - local tableTags = {} -- declare table - tableTags = dq:getTagArray() -- get table from DNSQuery - - local tablePB = {} - for k, v in pairs( tableTags) do - table.insert(tablePB, k .. "," .. v) - end - - protobuf:setTagArray(tablePB) -- store table in protobuf - protobuf:setTag("Query,123") -- add another tag entry in protobuf - - protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN - - local strReqName = dq.qname:toString() -- get request dns name - - protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time - - blobData={0x7F, 0x00, 0x00, 0x01} -- 127.0.0.1 - protobuf:setProtobufResponseRR(strReqName, 1, 1, 123, blobData) -- set protobuf to have a RR - - else - - local tableTags = {} -- called by testProtobuf() - table.insert(tableTags, "TestLabel1,TestData1") - table.insert(tableTags, "TestLabel2,TestData2") - - protobuf:setTagArray(tableTags) -- setTagArray - protobuf:setTag('TestLabel3,TestData3') -- setTag - protobuf:setTag("Query,123") -- setTag - end - end - - function alterLuaFirst(dq) -- called when dnsdist receives new request - local tt = {} - tt["TestLabel1"] = "TestData1" - tt["TestLabel2"] = "TestData2" - - dq:setTagArray(tt) --setTagArray - - dq:setTag("TestLabel3","TestData3") -- setTag - return DNSAction.None, "" -- continue to the next rule - end - - - newServer{address="127.0.0.1:%s", useClientSubnet=true} - rl = newRemoteLogger('127.0.0.1:%s') - - addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first - - addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup - - addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup - - """ - - @classmethod - def ProtobufListener(cls, port): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - try: - sock.bind(("127.0.0.1", port)) - except socket.error as e: - print("Error binding in the protbuf listener: %s" % str(e)) - sys.exit(1) - - sock.listen(100) - while True: - (conn, _) = sock.accept() - data = None - while True: - data = conn.recv(2) - if not data: - break - (datalen,) = struct.unpack("!H", data) - data = conn.recv(datalen) - if not data: - break - - cls._protobufQueue.put(data, True, timeout=2.0) - - conn.close() - sock.close() - - @classmethod - def startResponders(cls): - cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort]) - cls._UDPResponder.setDaemon(True) - cls._UDPResponder.start() - - cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort]) - cls._TCPResponder.setDaemon(True) - cls._TCPResponder.start() - - cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort]) - cls._protobufListener.setDaemon(True) - cls._protobufListener.start() - - def getFirstProtobufMessage(self): - self.assertFalse(self._protobufQueue.empty()) - data = self._protobufQueue.get(False) - self.assertTrue(data) - msg = dnsmessage_pb2.PBDNSMessage() - msg.ParseFromString(data) - return msg - - def checkProtobufBase(self, msg, protocol, query, initiator): - self.assertTrue(msg) - self.assertTrue(msg.HasField('timeSec')) - self.assertTrue(msg.HasField('socketFamily')) - self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET) - self.assertTrue(msg.HasField('from')) - fromvalue = getattr(msg, 'from') - self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator) - self.assertTrue(msg.HasField('socketProtocol')) - self.assertEquals(msg.socketProtocol, protocol) - self.assertTrue(msg.HasField('messageId')) - self.assertTrue(msg.HasField('id')) - self.assertEquals(msg.id, query.id) - self.assertTrue(msg.HasField('inBytes')) - self.assertEquals(msg.inBytes, len(query.to_wire())) - # dnsdist doesn't set the existing EDNS Subnet for now, - # although it might be set from Lua - # self.assertTrue(msg.HasField('originalRequestorSubnet')) - # self.assertEquals(len(msg.originalRequestorSubnet), 4) - # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1') - - - def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'): - - if initiator == '127.0.0.1': - self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) # testProtobuf() - else: - self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) # testLuaProtobuf() - - self.checkProtobufBase(msg, protocol, query, initiator) - # dnsdist doesn't fill the responder field for responses - # because it doesn't keep the information around. - self.assertTrue(msg.HasField('to')) - self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1') - self.assertTrue(msg.HasField('question')) - self.assertTrue(msg.question.HasField('qClass')) - self.assertEquals(msg.question.qClass, qclass) - self.assertTrue(msg.question.HasField('qType')) - self.assertEquals(msg.question.qClass, qtype) - self.assertTrue(msg.question.HasField('qName')) - self.assertEquals(msg.question.qName, qname) - - - - testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"] - listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list - - self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query") # exclusive or of lists should be empty - - def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'): - self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) - self.checkProtobufBase(msg, protocol, response, initiator) - self.assertTrue(msg.HasField('response')) - self.assertTrue(msg.response.HasField('queryTimeSec')) - - testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"] - listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list - self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response") # exclusive or of lists should be empty - - def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl): - self.assertTrue(record.HasField('class')) - self.assertEquals(getattr(record, 'class'), rclass) - self.assertTrue(record.HasField('type')) - self.assertEquals(record.type, rtype) - self.assertTrue(record.HasField('name')) - self.assertEquals(record.name, rname) - self.assertTrue(record.HasField('ttl')) - self.assertEquals(record.ttl, rttl) - self.assertTrue(record.HasField('rdata')) - - def testProtobuf(self): - """ - Protobuf: Send data to a protobuf server - """ - name = 'query.protobuf.tests.powerdns.com.' - - - target = 'target.protobuf.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN') - response = dns.message.make_response(query) - - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.CNAME, - target) - response.answer.append(rrset) - - rrset = dns.rrset.from_text(target, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '127.0.0.1') - response.answer.append(rrset) - - (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - - # let the protobuf messages the time to get there - time.sleep(1) - - # check the protobuf message corresponding to the UDP query - msg = self.getFirstProtobufMessage() - - self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) - - # check the protobuf message corresponding to the UDP response - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) # check UDP response - self.assertEquals(len(msg.response.rrs), 2) - rr = msg.response.rrs[0] - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) - self.assertEquals(rr.rdata, target) - rr = msg.response.rrs[1] - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600) - self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') - - (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - - # let the protobuf messages the time to get there - time.sleep(1) - - # check the protobuf message corresponding to the TCP query - msg = self.getFirstProtobufMessage() - self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name) - - # check the protobuf message corresponding to the TCP response - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) # check TCP response - self.assertEquals(len(msg.response.rrs), 2) - rr = msg.response.rrs[0] - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600) - self.assertEquals(rr.rdata, target) - rr = msg.response.rrs[1] - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600) - self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') - - def testLuaProtobuf(self): - - """ - Protobuf: Check that the Lua callback rewrote the initiator - """ - name = 'lua.protobuf.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN') - response = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '127.0.0.1') - response.answer.append(rrset) - - - (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) - - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - - - # let the protobuf messages the time to get there - time.sleep(1) - - # check the protobuf message corresponding to the UDP query - msg = self.getFirstProtobufMessage() - self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') - - # check the protobuf message corresponding to the UDP response - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') # check UDP Response - self.assertEquals(len(msg.response.rrs), 1) - for rr in msg.response.rrs: - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) - self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') - - (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - - # let the protobuf messages the time to get there - time.sleep(1) - - # check the protobuf message corresponding to the TCP query - msg = self.getFirstProtobufMessage() - self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') - - # check the protobuf message corresponding to the TCP response - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') # check TCP response - self.assertEquals(len(msg.response.rrs), 1) - for rr in msg.response.rrs: - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) - self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')