}
});
- g_lua.registerFunction<void(DNSQuestion::*)(std::string, std::string)>("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) {
+ g_lua.registerFunction<void(DNSQuestion::*)(std::string, std::string)>("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) {
- if(dq.qTag == NULL)
- {
- dq.qTag = std::shared_ptr<QTag>(new QTag);
- }
+ if(dq.qTag == nullptr) {
+ dq.qTag = std::make_shared<QTag>();
+ }
dq.qTag->add(strLabel, strValue);
});
- g_lua.registerFunction<void(DNSQuestion::*)(vector<pair<string, string>>)>("setTagArray", [](DNSQuestion& dq, const vector<pair<string, string>>&tags) {
+ g_lua.registerFunction<void(DNSQuestion::*)(vector<pair<string, string>>)>("setTagArray", [](DNSQuestion& dq, const vector<pair<string, string>>&tags) {
- if(dq.qTag == NULL)
- {
- dq.qTag = std::shared_ptr<QTag>(new QTag);
- }
+ if(dq.qTag == nullptr) {
+ dq.qTag = std::make_shared<QTag>();
+ }
- for (const auto& tag : tags)
- {
- dq.qTag->add(tag.first, tag.second);
- }
+ for (const auto& tag : tags) {
+ dq.qTag->add(tag.first, tag.second);
+ }
});
- g_lua.registerFunction<string(DNSQuestion::*)(std::string)>("getTagMatch", [](const DNSQuestion& dq, const std::string& strLabel) {
+ g_lua.registerFunction<string(DNSQuestion::*)(std::string)>("getTag", [](const DNSQuestion& dq, const std::string& strLabel) {
std::string strValue;
- if(dq.qTag != NULL)
- {
- strValue = dq.qTag->getMatch(strLabel);
- }
+ if(dq.qTag != nullptr) {
+ strValue = dq.qTag->getMatch(strLabel);
+ }
return strValue;
});
- g_lua.registerFunction<std::unordered_map<string, string>(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) {
+ g_lua.registerFunction<std::unordered_map<string, string>(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) {
- if(dq.qTag != NULL)
- {
- return dq.qTag->tagData;
- }
- else
- {
- std::unordered_map<string, string> XX;
- return(XX);
- }
+ if(dq.qTag != nullptr) {
+ return dq.qTag->tagData;
+ } else {
+ std::unordered_map<string, string> XX;
+ return XX;
+ }
});
#endif
});
- g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(std::string)>("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) {
-
+ g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(std::string)>("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) {
message.addTag(strValue);
});
- g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(vector<pair<int, string>>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector<pair<int, string>>&tags) {
-
-
- for (const auto& tag : tags)
- {
- message.addTag(tag.second);
- }
+ g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(vector<pair<int, string>>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector<pair<int, string>>&tags) {
+ for (const auto& tag : tags) {
+ message.addTag(tag.second);
+ }
});
- g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint> uSec)>("setProtobufResponseType",
+ g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint> uSec)>("setProtobufResponseType",
[](DNSDistProtoBufMessage& message, boost::optional <time_t> sec, boost::optional <uint> uSec) {
-
message.setType(DNSProtoBufMessage::Response);
-
message.setQueryTime(sec?*sec:0, uSec?*uSec:0);
-
});
- g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string&, uint uType, uint uClass, uint uTTL, vector<pair<int, int>> )>("setProtobufResponseRR", [](DNSDistProtoBufMessage& message,
- const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const vector<pair<int, int>>& blobData) {
-
- size_t blobSize = blobData.size();
-
- unique_ptr<uint8_t[]> ptrBlob (new uint8_t(blobSize));
-
- int jj=0;
- for (const auto& blob : blobData)
- {
- ptrBlob[jj++] = blob.second;
- }
-
- message.addRR(strQueryName, uType, uClass, uTTL, ptrBlob.get(), blobSize);
-
+ g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const std::string& strBlob)>("addResponseRR", [](DNSDistProtoBufMessage& message,
+ const std::string& strQueryName, uint16_t uType, uint uClass, uint32_t uTTL, const std::string& strBlob) {
+ message.addRR(strQueryName, uType, uClass, uTTL, strBlob);
});
-
g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const Netmask&)>("setEDNSSubnet", [](DNSDistProtoBufMessage& message, const Netmask& subnet) { message.setEDNSSubnet(subnet); });
g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const DNSName&, uint16_t, uint16_t)>("setQuestion", [](DNSDistProtoBufMessage& message, const DNSName& qname, uint16_t qtype, uint16_t qclass) { message.setQuestion(qname, qtype, qclass); });
g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(size_t)>("setBytes", [](DNSDistProtoBufMessage& message, size_t bytes) { message.setBytes(bytes); });
{
}
-bool add(std::string strLabel, std::string strValue)
+void add(std::string strLabel, std::string strValue)
{
-bool bStatus = true;
- tagData.insert( {strLabel, strValue});
- return(bStatus);
+ tagData.insert( {strLabel, strValue});
+ return;
}
std::string getMatch(const std::string& strLabel) const
{
- std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
- if(got == tagData.end())
- {
- return("");
- }
- else
- {
- return(got->second);
- }
+ std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
+ if(got == tagData.end()) {
+ return "";
+ } else {
+ return got->second;
+ }
}
std::string getEntry(size_t iEntry) const
size_t iCounter = 0;
- for (const auto& itr : tagData)
- {
- iCounter++;
- if(iCounter == iEntry)
- {
- strEntry = itr.first;
- strEntry += strSep;
- strEntry += itr.second;
- break;
- }
- }
+ for (const auto& itr : tagData) {
+ iCounter++;
+ if(iCounter == iEntry) {
+ strEntry = itr.first;
+ strEntry += strSep;
+ strEntry += itr.second;
+ break;
+ }
+ }
- return(strEntry);
+ return strEntry;
}
size_t count() const
{
- return(tagData.size());
+ return tagData.size();
}
std::string dumpString() const
std::string strRet;
- for (const auto& itr : tagData)
- {
- strRet += itr.first;
- strRet += strSep;
- strRet += itr.second;
- strRet += "\n";
- }
- return(strRet);
-
+ for (const auto& itr : tagData) {
+ strRet += itr.first;
+ strRet += strSep;
+ strRet += itr.second;
+ strRet += "\n";
+ }
+ return strRet;
}
#endif /* HAVE_PROTOBUF */
}
-void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const uint8_t *ptrBlob, size_t uBlobLen)
+void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const std::string& strBlob)
{
#ifdef HAVE_PROTOBUF
PBDNSMessage_DNSResponse_DNSRR* rr = response->add_rrs();
if (rr) {
- string blob;
rr->set_name(strName.c_str());
rr->set_type(uType);
rr->set_class_(uClass);
rr->set_ttl(uTTL);
- rr->set_rdata(ptrBlob, uBlobLen);
+ rr->set_rdata((const uint8_t *) strBlob.c_str(), strBlob.size());
}
#endif /* HAVE_PROTOBUF */
void setRequestorId(const std::string& requestorId);
std::string toDebugString() const;
void addTag(const std::string& strValue);
- void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const uint8_t *ptrBlob, size_t uBlobLen);
+ void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const std::string& strBlob);
#ifdef HAVE_PROTOBUF
DNSProtoBufMessage(DNSProtoBufMessage::DNSProtoBufMessageType type, const boost::uuids::uuid& uuid, const ComboAddress* requestor, const ComboAddress* responder, const DNSName& domain, int qtype, uint16_t qclass, uint16_t qid, bool isTCP, size_t bytes);
import dns
import dnsmessage_pb2
-class TestProtobuf(DNSDistTest):
+class TestProtobuf(DNSDistTest):
_protobufServerPort = 4242
_protobufQueue = Queue.Queue()
_protobufCounter = 0
_config_params = ['_testServerPort', '_protobufServerPort']
_config_template = """
- luasmn = newSuffixMatchNode()
- luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
-
- function alterProtobuf(dq, protobuf)
- if luasmn:check(dq.qname) then
- requestor = newCA(dq.remoteaddr:toString())
- if requestor:isIPv4() then
- requestor:truncate(24)
+ luasmn = newSuffixMatchNode()
+ luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+ function alterProtobufResponse(dq, protobuf)
+ if luasmn:check(dq.qname) then
+ requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
+ if requestor:isIPv4() then
+ requestor:truncate(24)
else
- requestor:truncate(56)
+ requestor:truncate(56)
end
- protobuf:setRequestor(requestor)
+ protobuf:setRequestor(requestor)
+
+ local tableTags = {}
+ table.insert(tableTags, "TestLabel1,TestData1")
+ table.insert(tableTags, "TestLabel2,TestData2")
+
+ protobuf:setTagArray(tableTags)
+
+ protobuf:setTag('TestLabel3,TestData3')
+
+ protobuf:setTag("Response,456")
+ else
+ local tableTags = {} -- called by testProtobuf()
+ table.insert(tableTags, "TestLabel1,TestData1")
+ table.insert(tableTags, "TestLabel2,TestData2")
+ protobuf:setTagArray(tableTags)
+
+ protobuf:setTag('TestLabel3,TestData3')
+
+ protobuf:setTag("Response,456")
+ end
+ end
+
+ function alterProtobufQuery(dq, protobuf)
+ if luasmn:check(dq.qname) then
+ requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
+ if requestor:isIPv4() then
+ requestor:truncate(24)
+ else
+ requestor:truncate(56)
+ end
+ protobuf:setRequestor(requestor)
+
+ local tableTags = {}
+ tableTags = dq:getTagArray() -- get table from DNSQuery
+
+ local tablePB = {}
+ for k, v in pairs( tableTags) do
+ table.insert(tablePB, k .. "," .. v)
+ end
+
+ protobuf:setTagArray(tablePB) -- store table in protobuf
+ protobuf:setTag("Query,123") -- add another tag entry in protobuf
+
+ protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN
+
+ local strReqName = dq.qname:toString() -- get request dns name
+
+ protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time
+
+ blobData = '\127' .. '\000' .. '\000' .. '\001' -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
+ protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
+
+ else
+
+ local tableTags = {} -- called by testProtobuf()
+ table.insert(tableTags, "TestLabel1,TestData1")
+ table.insert(tableTags, "TestLabel2,TestData2")
+
+ protobuf:setTagArray(tableTags)
+ protobuf:setTag('TestLabel3,TestData3')
+ protobuf:setTag("Query,123")
end
end
+ function alterLuaFirst(dq) -- called when dnsdist receives new request
+ local tt = {}
+ tt["TestLabel1"] = "TestData1"
+ tt["TestLabel2"] = "TestData2"
+
+ dq:setTagArray(tt)
+
+ dq:setTag("TestLabel3","TestData3")
+ return DNSAction.None, "" -- continue to the next rule
+ end
+
+
newServer{address="127.0.0.1:%s", useClientSubnet=true}
rl = newRemoteLogger('127.0.0.1:%s')
- addAction(AllRule(), RemoteLogAction(rl, alterProtobuf))
- addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf, true))
+
+ addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first
+
+ addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup
+
+ addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup
+
"""
@classmethod
cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
cls._UDPResponder.setDaemon(True)
cls._UDPResponder.start()
+
cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
cls._TCPResponder.setDaemon(True)
cls._TCPResponder.start()
self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
self.assertTrue(msg.HasField('from'))
fromvalue = getattr(msg, 'from')
- self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
+ self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
self.assertTrue(msg.HasField('socketProtocol'))
self.assertEquals(msg.socketProtocol, protocol)
self.assertTrue(msg.HasField('messageId'))
self.assertTrue(msg.HasField('id'))
- self.assertEquals(msg.id, query.id)
+ self.assertEquals(msg.id, query.id)
self.assertTrue(msg.HasField('inBytes'))
- self.assertEquals(msg.inBytes, len(query.to_wire()))
+ self.assertEquals(msg.inBytes, len(query.to_wire()))
# dnsdist doesn't set the existing EDNS Subnet for now,
# although it might be set from Lua
# self.assertTrue(msg.HasField('originalRequestorSubnet'))
# self.assertEquals(len(msg.originalRequestorSubnet), 4)
# self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
+
def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
- self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
+
+ if initiator == '127.0.0.1':
+ self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) # testProtobuf()
+ else:
+ self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) # testLuaProtobuf()
+
self.checkProtobufBase(msg, protocol, query, initiator)
# dnsdist doesn't fill the responder field for responses
# because it doesn't keep the information around.
self.assertTrue(msg.question.HasField('qName'))
self.assertEquals(msg.question.qName, qname)
+
+
+ testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
+ listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list
+
+ self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query") # exclusive or of lists should be empty
+
def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
- self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+ self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
self.checkProtobufBase(msg, protocol, response, initiator)
- self.assertTrue(msg.HasField('response'))
- self.assertTrue(msg.response.HasField('queryTimeSec'))
+ self.assertTrue(msg.HasField('response'))
+ self.assertTrue(msg.response.HasField('queryTimeSec'))
+
+ testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
+ listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list
+ self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response") # exclusive or of lists should be empty
def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
self.assertTrue(record.HasField('class'))
"""
Protobuf: Send data to a protobuf server
"""
- name = 'query.protobuf.tests.powerdns.com.'
+ name = 'query.protobuf.tests.powerdns.com.'
+
+
target = 'target.protobuf.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN')
response = dns.message.make_response(query)
+
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
dns.rdatatype.CNAME,
target)
response.answer.append(rrset)
+
rrset = dns.rrset.from_text(target,
3600,
dns.rdataclass.IN,
# check the protobuf message corresponding to the UDP query
msg = self.getFirstProtobufMessage()
- self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
# check the protobuf message corresponding to the UDP response
msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
- self.assertEquals(len(msg.response.rrs), 2)
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) # check UDP response
+ self.assertEquals(len(msg.response.rrs), 2)
rr = msg.response.rrs[0]
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
self.assertEquals(rr.rdata, target)
# check the protobuf message corresponding to the TCP response
msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) # check TCP response
self.assertEquals(len(msg.response.rrs), 2)
rr = msg.response.rrs[0]
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
def testLuaProtobuf(self):
+
"""
Protobuf: Check that the Lua callback rewrote the initiator
"""
'127.0.0.1')
response.answer.append(rrset)
+
(receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
+
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
receivedQuery.id = query.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
+
# let the protobuf messages the time to get there
time.sleep(1)
# check the protobuf message corresponding to the UDP response
msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') # check UDP Response
self.assertEquals(len(msg.response.rrs), 1)
for rr in msg.response.rrs:
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
# check the protobuf message corresponding to the TCP response
msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') # check TCP response
self.assertEquals(len(msg.response.rrs), 1)
for rr in msg.response.rrs:
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+++ /dev/null
-#!/usr/bin/env python
-import Queue
-import threading
-import socket
-import struct
-import sys
-import time
-from dnsdisttests import DNSDistTest
-
-import dns
-import dnsmessage_pb2
-
-
-class TestProtobuf(DNSDistTest):
- _protobufServerPort = 4242
- _protobufQueue = Queue.Queue()
- _protobufCounter = 0
- _config_params = ['_testServerPort', '_protobufServerPort']
- _config_template = """
- luasmn = newSuffixMatchNode()
- luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
-
- function alterProtobufResponse(dq, protobuf)
- if luasmn:check(dq.qname) then
- requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
- if requestor:isIPv4() then
- requestor:truncate(24)
- else
- requestor:truncate(56)
- end
- protobuf:setRequestor(requestor)
-
- local tableTags = {}
- table.insert(tableTags, "TestLabel1,TestData1")
- table.insert(tableTags, "TestLabel2,TestData2")
-
- -- tableTags["TestLabel2"] = "TestData2"
--- tableTags["TestLabel1"] = "TestData1"
- protobuf:setTagArray(tableTags) -- setTagArray
-
- protobuf:setTag('TestLabel3,TestData3') -- setTag
-
- protobuf:setTag("Response,456") -- setTag
- else
- local tableTags = {} -- called by testProtobuf()
- table.insert(tableTags, "TestLabel1,TestData1")
- table.insert(tableTags, "TestLabel2,TestData2")
- protobuf:setTagArray(tableTags) -- setTagArray
-
- protobuf:setTag('TestLabel3,TestData3') -- setTag
-
- protobuf:setTag("Response,456") -- setTag
- end
- end
-
- function alterProtobufQuery(dq, protobuf)
- if luasmn:check(dq.qname) then
- requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
- if requestor:isIPv4() then
- requestor:truncate(24)
- else
- requestor:truncate(56)
- end
- protobuf:setRequestor(requestor)
-
- local tableTags = {} -- declare table
- tableTags = dq:getTagArray() -- get table from DNSQuery
-
- local tablePB = {}
- for k, v in pairs( tableTags) do
- table.insert(tablePB, k .. "," .. v)
- end
-
- protobuf:setTagArray(tablePB) -- store table in protobuf
- protobuf:setTag("Query,123") -- add another tag entry in protobuf
-
- protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN
-
- local strReqName = dq.qname:toString() -- get request dns name
-
- protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time
-
- blobData={0x7F, 0x00, 0x00, 0x01} -- 127.0.0.1
- protobuf:setProtobufResponseRR(strReqName, 1, 1, 123, blobData) -- set protobuf to have a RR
-
- else
-
- local tableTags = {} -- called by testProtobuf()
- table.insert(tableTags, "TestLabel1,TestData1")
- table.insert(tableTags, "TestLabel2,TestData2")
-
- protobuf:setTagArray(tableTags) -- setTagArray
- protobuf:setTag('TestLabel3,TestData3') -- setTag
- protobuf:setTag("Query,123") -- setTag
- end
- end
-
- function alterLuaFirst(dq) -- called when dnsdist receives new request
- local tt = {}
- tt["TestLabel1"] = "TestData1"
- tt["TestLabel2"] = "TestData2"
-
- dq:setTagArray(tt) --setTagArray
-
- dq:setTag("TestLabel3","TestData3") -- setTag
- return DNSAction.None, "" -- continue to the next rule
- end
-
-
- newServer{address="127.0.0.1:%s", useClientSubnet=true}
- rl = newRemoteLogger('127.0.0.1:%s')
-
- addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first
-
- addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup
-
- addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup
-
- """
-
- @classmethod
- def ProtobufListener(cls, port):
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- try:
- sock.bind(("127.0.0.1", port))
- except socket.error as e:
- print("Error binding in the protbuf listener: %s" % str(e))
- sys.exit(1)
-
- sock.listen(100)
- while True:
- (conn, _) = sock.accept()
- data = None
- while True:
- data = conn.recv(2)
- if not data:
- break
- (datalen,) = struct.unpack("!H", data)
- data = conn.recv(datalen)
- if not data:
- break
-
- cls._protobufQueue.put(data, True, timeout=2.0)
-
- conn.close()
- sock.close()
-
- @classmethod
- def startResponders(cls):
- cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
- cls._UDPResponder.setDaemon(True)
- cls._UDPResponder.start()
-
- cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
- cls._TCPResponder.setDaemon(True)
- cls._TCPResponder.start()
-
- cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
- cls._protobufListener.setDaemon(True)
- cls._protobufListener.start()
-
- def getFirstProtobufMessage(self):
- self.assertFalse(self._protobufQueue.empty())
- data = self._protobufQueue.get(False)
- self.assertTrue(data)
- msg = dnsmessage_pb2.PBDNSMessage()
- msg.ParseFromString(data)
- return msg
-
- def checkProtobufBase(self, msg, protocol, query, initiator):
- self.assertTrue(msg)
- self.assertTrue(msg.HasField('timeSec'))
- self.assertTrue(msg.HasField('socketFamily'))
- self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
- self.assertTrue(msg.HasField('from'))
- fromvalue = getattr(msg, 'from')
- self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
- self.assertTrue(msg.HasField('socketProtocol'))
- self.assertEquals(msg.socketProtocol, protocol)
- self.assertTrue(msg.HasField('messageId'))
- self.assertTrue(msg.HasField('id'))
- self.assertEquals(msg.id, query.id)
- self.assertTrue(msg.HasField('inBytes'))
- self.assertEquals(msg.inBytes, len(query.to_wire()))
- # dnsdist doesn't set the existing EDNS Subnet for now,
- # although it might be set from Lua
- # self.assertTrue(msg.HasField('originalRequestorSubnet'))
- # self.assertEquals(len(msg.originalRequestorSubnet), 4)
- # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
-
-
- def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
-
- if initiator == '127.0.0.1':
- self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) # testProtobuf()
- else:
- self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) # testLuaProtobuf()
-
- self.checkProtobufBase(msg, protocol, query, initiator)
- # dnsdist doesn't fill the responder field for responses
- # because it doesn't keep the information around.
- self.assertTrue(msg.HasField('to'))
- self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
- self.assertTrue(msg.HasField('question'))
- self.assertTrue(msg.question.HasField('qClass'))
- self.assertEquals(msg.question.qClass, qclass)
- self.assertTrue(msg.question.HasField('qType'))
- self.assertEquals(msg.question.qClass, qtype)
- self.assertTrue(msg.question.HasField('qName'))
- self.assertEquals(msg.question.qName, qname)
-
-
-
- testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
- listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list
-
- self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query") # exclusive or of lists should be empty
-
- def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
- self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
- self.checkProtobufBase(msg, protocol, response, initiator)
- self.assertTrue(msg.HasField('response'))
- self.assertTrue(msg.response.HasField('queryTimeSec'))
-
- testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
- listx = set(msg.response.tags) ^ set(testList) # only differences will be in new list
- self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response") # exclusive or of lists should be empty
-
- def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
- self.assertTrue(record.HasField('class'))
- self.assertEquals(getattr(record, 'class'), rclass)
- self.assertTrue(record.HasField('type'))
- self.assertEquals(record.type, rtype)
- self.assertTrue(record.HasField('name'))
- self.assertEquals(record.name, rname)
- self.assertTrue(record.HasField('ttl'))
- self.assertEquals(record.ttl, rttl)
- self.assertTrue(record.HasField('rdata'))
-
- def testProtobuf(self):
- """
- Protobuf: Send data to a protobuf server
- """
- name = 'query.protobuf.tests.powerdns.com.'
-
-
- target = 'target.protobuf.tests.powerdns.com.'
- query = dns.message.make_query(name, 'A', 'IN')
- response = dns.message.make_response(query)
-
- rrset = dns.rrset.from_text(name,
- 3600,
- dns.rdataclass.IN,
- dns.rdatatype.CNAME,
- target)
- response.answer.append(rrset)
-
- rrset = dns.rrset.from_text(target,
- 3600,
- dns.rdataclass.IN,
- dns.rdatatype.A,
- '127.0.0.1')
- response.answer.append(rrset)
-
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- # let the protobuf messages the time to get there
- time.sleep(1)
-
- # check the protobuf message corresponding to the UDP query
- msg = self.getFirstProtobufMessage()
-
- self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
-
- # check the protobuf message corresponding to the UDP response
- msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) # check UDP response
- self.assertEquals(len(msg.response.rrs), 2)
- rr = msg.response.rrs[0]
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
- self.assertEquals(rr.rdata, target)
- rr = msg.response.rrs[1]
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
- self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- # let the protobuf messages the time to get there
- time.sleep(1)
-
- # check the protobuf message corresponding to the TCP query
- msg = self.getFirstProtobufMessage()
- self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
-
- # check the protobuf message corresponding to the TCP response
- msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) # check TCP response
- self.assertEquals(len(msg.response.rrs), 2)
- rr = msg.response.rrs[0]
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
- self.assertEquals(rr.rdata, target)
- rr = msg.response.rrs[1]
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
- self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
- def testLuaProtobuf(self):
-
- """
- Protobuf: Check that the Lua callback rewrote the initiator
- """
- name = 'lua.protobuf.tests.powerdns.com.'
- query = dns.message.make_query(name, 'A', 'IN')
- response = dns.message.make_response(query)
- rrset = dns.rrset.from_text(name,
- 3600,
- dns.rdataclass.IN,
- dns.rdatatype.A,
- '127.0.0.1')
- response.answer.append(rrset)
-
-
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
-
- # let the protobuf messages the time to get there
- time.sleep(1)
-
- # check the protobuf message corresponding to the UDP query
- msg = self.getFirstProtobufMessage()
- self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
-
- # check the protobuf message corresponding to the UDP response
- msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') # check UDP Response
- self.assertEquals(len(msg.response.rrs), 1)
- for rr in msg.response.rrs:
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
- self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- # let the protobuf messages the time to get there
- time.sleep(1)
-
- # check the protobuf message corresponding to the TCP query
- msg = self.getFirstProtobufMessage()
- self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
-
- # check the protobuf message corresponding to the TCP response
- msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') # check TCP response
- self.assertEquals(len(msg.response.rrs), 1)
- for rr in msg.response.rrs:
- self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
- self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')