self.assertEquals(len(msg.originalRequestorSubnet), 4)
self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
- def checkOutgoingProtobufBase(self, msg, protocol, query, initiator):
+ def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None):
self.assertTrue(msg)
self.assertTrue(msg.HasField('timeSec'))
self.assertTrue(msg.HasField('socketFamily'))
self.assertTrue(msg.HasField('id'))
self.assertNotEquals(msg.id, query.id)
self.assertTrue(msg.HasField('inBytes'))
- # compare inBytes with length of query/response
- self.assertEquals(msg.inBytes, len(query.to_wire()))
+ if length is not None:
+ self.assertEquals(msg.inBytes, length)
+ else:
+ # compare inBytes with length of query/response
+ self.assertEquals(msg.inBytes, len(query.to_wire()))
def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
for tag in msg.response.tags:
self.assertTrue(tag in tags)
- def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
+ def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', length=None):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSOutgoingQueryType)
- self.checkOutgoingProtobufBase(msg, protocol, query, initiator)
+ self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length)
self.assertTrue(msg.HasField('to'))
self.assertTrue(msg.HasField('question'))
self.assertTrue(msg.question.HasField('qClass'))
self.assertTrue(msg.question.HasField('qName'))
self.assertEquals(msg.question.qName, qname)
- def checkProtobufIncomingResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+ def checkProtobufIncomingResponse(self, msg, protocol, response, initiator='127.0.0.1', length=None):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSIncomingResponseType)
- self.checkOutgoingProtobufBase(msg, protocol, response, initiator)
+ self.checkOutgoingProtobufBase(msg, protocol, response, initiator, length=length)
self.assertTrue(msg.HasField('response'))
+ self.assertTrue(msg.response.HasField('rcode'))
self.assertTrue(msg.response.HasField('queryTimeSec'))
+ def checkProtobufIncomingNetworkErrorResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+ self.checkProtobufIncomingResponse(msg, protocol, response, initiator, length=0)
+ self.assertEquals(msg.response.rcode, 65536)
+
@classmethod
def setUpClass(cls):
# check the protobuf messages corresponding to the UDP query and answer
msg = self.getFirstProtobufMessage()
self.checkProtobufOutgoingQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
-# # then the response
-# msg = self.getFirstProtobufMessage()
-# self.checkProtobufIncomingResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+ # then the response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufIncomingNetworkErrorResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
self.checkNoRemainingMessage()
class OutgoingProtobufNoQueriesTest(TestRecursorProtobuf):
query.flags |= dns.flags.RD
res = self.sendUDPQuery(query)
-# # check the response
-# msg = self.getFirstProtobufMessage()
-# self.checkProtobufIncomingResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
- # let's wait a bit for a potential message to arrive
- time.sleep(2)
+ # check the response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufIncomingNetworkErrorResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
self.checkNoRemainingMessage()
class ProtobufMasksTest(TestRecursorProtobuf):