_protobufCounter = 0
_config_params = ['_testServerPort', '_protobufServerPort']
_config_template = """
+ luasmn = newSuffixMatchNode()
+ luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+ function alterProtobuf(dq, protobuf)
+ if luasmn:check(dq.qname) then
+ requestor = newCA(dq.remoteaddr:toString())
+ if requestor:isIPv4() then
+ requestor:truncate(24)
+ else
+ requestor:truncate(56)
+ end
+ protobuf:setRequestor(requestor)
+ end
+ end
+
newServer{address="127.0.0.1:%s", useClientSubnet=true}
rl = newRemoteLogger('127.0.0.1:%s')
- addAction(AllRule(), RemoteLogAction(rl))
- addResponseAction(AllRule(), RemoteLogResponseAction(rl))
+ addAction(AllRule(), RemoteLogAction(rl, alterProtobuf))
+ addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf))
"""
@classmethod
msg.ParseFromString(data)
return msg
- def checkProtobufBase(self, msg, protocol, query):
+ def checkProtobufBase(self, msg, protocol, query, initiator):
self.assertTrue(msg)
self.assertTrue(msg.HasField('timeSec'))
self.assertTrue(msg.HasField('socketFamily'))
self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
self.assertTrue(msg.HasField('from'))
fromvalue = getattr(msg, 'from')
- self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1')
+ self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
self.assertTrue(msg.HasField('socketProtocol'))
self.assertEquals(msg.socketProtocol, protocol)
self.assertTrue(msg.HasField('messageId'))
# self.assertEquals(len(msg.originalRequestorSubnet), 4)
# self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
- def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname):
+ def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
- self.checkProtobufBase(msg, protocol, query)
+ self.checkProtobufBase(msg, protocol, query, initiator)
# dnsdist doesn't fill the responder field for responses
# because it doesn't keep the information around.
self.assertTrue(msg.HasField('to'))
self.assertTrue(msg.question.HasField('qName'))
self.assertEquals(msg.question.qName, qname)
- def checkProtobufResponse(self, msg, protocol, response):
+ def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
- self.checkProtobufBase(msg, protocol, response)
+ self.checkProtobufBase(msg, protocol, response, initiator)
self.assertTrue(msg.HasField('response'))
self.assertTrue(msg.response.HasField('queryTimeSec'))
+ def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
+ self.assertTrue(record.HasField('class'))
+ self.assertEquals(getattr(record, 'class'), rclass)
+ self.assertTrue(record.HasField('type'))
+ self.assertEquals(record.type, rtype)
+ self.assertTrue(record.HasField('name'))
+ self.assertEquals(record.name, rname)
+ self.assertTrue(record.HasField('ttl'))
+ self.assertEquals(record.ttl, rttl)
+ self.assertTrue(record.HasField('rdata'))
+
def testProtobuf(self):
"""
Protobuf: Send data to a protobuf server
self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
self.assertEquals(len(msg.response.rrs), 1)
for rr in msg.response.rrs:
- self.assertTrue(rr.HasField('class'))
- self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
- self.assertTrue(rr.HasField('type'))
- self.assertEquals(rr.type, dns.rdatatype.A)
- self.assertTrue(rr.HasField('name'))
- self.assertEquals(rr.name, name)
- self.assertTrue(rr.HasField('ttl'))
- self.assertEquals(rr.ttl, 3600)
- self.assertTrue(rr.HasField('rdata'))
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
(receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
self.assertEquals(len(msg.response.rrs), 1)
for rr in msg.response.rrs:
- self.assertTrue(rr.HasField('class'))
- self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
- self.assertTrue(rr.HasField('type'))
- self.assertEquals(rr.type, dns.rdatatype.A)
- self.assertTrue(rr.HasField('name'))
- self.assertEquals(rr.name, name)
- self.assertTrue(rr.HasField('ttl'))
- self.assertEquals(rr.ttl, 3600)
- self.assertTrue(rr.HasField('rdata'))
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+ self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+ def testLuaProtobuf(self):
+ """
+ Protobuf: Check that the Lua callback rewrote the initiator
+ """
+ name = 'lua.protobuf.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
+ # let the protobuf messages the time to get there
+ time.sleep(1)
+
+ # check the protobuf message corresponding to the UDP query
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+ # check the protobuf message corresponding to the UDP response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
+ self.assertEquals(len(msg.response.rrs), 1)
+ for rr in msg.response.rrs:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+ self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
+ # let the protobuf messages the time to get there
+ time.sleep(1)
+
+ # check the protobuf message corresponding to the TCP query
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+ # check the protobuf message corresponding to the TCP response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
+ self.assertEquals(len(msg.response.rrs), 1)
+ for rr in msg.response.rrs:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')