]> granicus.if.org Git - pdns/commitdiff
Fixed requests from Remi July 7
authorSeth Ornstein <sornstein@globalcyberalliance.org>
Wed, 12 Jul 2017 19:14:01 +0000 (15:14 -0400)
committerNick Saika <nsaika@globalcyberalliance.org>
Wed, 2 Aug 2017 15:40:52 +0000 (11:40 -0400)
pdns/dnsdist-lua2.cc
pdns/dnsdist.hh
pdns/protobuf.cc
pdns/protobuf.hh
regression-tests.dnsdist/test_Protobuf.py

index f31046c628702024a6c9b49d1f58dff5d6b6b85b..0c7639205eea1f1ca038de83ee2a24a08aab6b3f 100644 (file)
@@ -841,14 +841,14 @@ void moreLua(bool client)
       }
      });
 
-    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint> uSec)>("setProtobufResponseType",
-                                        [](DNSDistProtoBufMessage& message, boost::optional <time_t> sec, boost::optional <uint> uSec) {
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint32_t> uSec)>("setProtobufResponseType",
+                                        [](DNSDistProtoBufMessage& message, boost::optional <time_t> sec, boost::optional <uint32_t> uSec) {
         message.setType(DNSProtoBufMessage::Response);
         message.setQueryTime(sec?*sec:0, uSec?*uSec:0);
      });
 
-    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const std::string& strBlob)>("addResponseRR", [](DNSDistProtoBufMessage& message,
-                                                            const std::string& strQueryName, uint16_t uType, uint uClass, uint32_t uTTL, const std::string& strBlob) {
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string& strQueryName, uint16_t uType, uint16_t uClass, uint32_t uTTL, const std::string& strBlob)>("addResponseRR", [](DNSDistProtoBufMessage& message,
+                                                            const std::string& strQueryName, uint16_t uType, uint16_t uClass, uint32_t uTTL, const std::string& strBlob) {
         message.addRR(strQueryName, uType, uClass, uTTL, strBlob);
      });
 
index bb427a8baa4a1bcd1847e32e5750216919d00c0f..cbd3ebdd950a44f31d3fce87f1a8e1c9bdbb9310 100644 (file)
@@ -70,14 +70,12 @@ QTag()
 
 void add(std::string strLabel, std::string strValue)
 {
-
   tagData.insert( {strLabel, strValue});
   return;
 }
 
 std::string getMatch(const std::string& strLabel)  const
 {
-
   std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
   if(got == tagData.end()) {
     return "";
@@ -88,9 +86,8 @@ std::string getMatch(const std::string& strLabel)  const
 
 std::string getEntry(size_t iEntry) const
 {
-std::string strEntry;
-size_t iCounter = 0;
-
+   std::string strEntry;
+   size_t iCounter = 0;
 
   for (const auto& itr : tagData) {
     iCounter++;
@@ -113,8 +110,7 @@ size_t count() const
 
 std::string dumpString() const
 {
-std::string strRet;
-
+  std::string strRet;
 
   for (const auto& itr : tagData) {
     strRet += itr.first;
@@ -127,11 +123,10 @@ std::string strRet;
 }
 
 public:
-    std::unordered_map<std::string, std::string>tagData;
+  std::unordered_map<std::string, std::string>tagData;
 
 private:
-
-    const char *strSep = "\t";
+  const char *strSep = "\t";
 };
 
 
index 66c0c8b22e81fa6b2f8ec3981c5ab0e89513eb11..a2acc254b7dae9a31c6ffe77f47e144c7771a159 100644 (file)
@@ -110,22 +110,21 @@ void DNSProtoBufMessage::addTag(const std::string& strValue)
 #endif /* HAVE_PROTOBUF */
 }
 
-void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const std::string& strBlob)
+void DNSProtoBufMessage::addRR(const std::string& strName, uint16_t uType, uint16_t uClass, uint32_t uTTL, const std::string& strBlob)
 {
 #ifdef HAVE_PROTOBUF
 
   PBDNSMessage_DNSResponse* response = d_message.mutable_response();
   if (!response)
     return;
-
   PBDNSMessage_DNSResponse_DNSRR* rr = response->add_rrs();
-      if (rr) {
-        rr->set_name(strName.c_str());
-        rr->set_type(uType);
-        rr->set_class_(uClass);
-        rr->set_ttl(uTTL);
-        rr->set_rdata((const uint8_t *) strBlob.c_str(), strBlob.size());
-      }
+  if (rr) {
+    rr->set_name(strName.c_str());
+    rr->set_type(uType);
+    rr->set_class_(uClass);
+    rr->set_ttl(uTTL);
+    rr->set_rdata((const uint8_t *) strBlob.c_str(), strBlob.size());
+  }
 
 #endif /* HAVE_PROTOBUF */
 }
index 9decfdc24110928d97edffbcdb2454a130f90dd1..0d32a2f90f33bbe4741de907f8fbf74c58d56282 100644 (file)
@@ -71,7 +71,7 @@ public:
   void setRequestorId(const std::string& requestorId);
   std::string toDebugString() const;
   void addTag(const std::string& strValue);
-  void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const std::string& strBlob);
+  void addRR(const std::string& strName, uint16_t utype, uint16_t uClass, uint32_t uTTl, const std::string& strBlob);
 
 #ifdef HAVE_PROTOBUF
   DNSProtoBufMessage(DNSProtoBufMessage::DNSProtoBufMessageType type, const boost::uuids::uuid& uuid, const ComboAddress* requestor, const ComboAddress* responder, const DNSName& domain, int qtype, uint16_t qclass, uint16_t qid, bool isTCP, size_t bytes);
index ba3ffdf88c4a526ba995bdf3d501de748e50a23b..fd9163c6fe0616cac497f1dc1ab260fbbc0e624b 100644 (file)
@@ -17,51 +17,54 @@ class TestProtobuf(DNSDistTest):
     _protobufCounter = 0
     _config_params = ['_testServerPort', '_protobufServerPort']
     _config_template = """
-    luasmn = newSuffixMatchNode()                              
-    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) 
-
-    function alterProtobufResponse(dq, protobuf)               
-      if luasmn:check(dq.qname) then                           
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()  
-        if requestor:isIPv4() then                             
-          requestor:truncate(24)                               
+    luasmn = newSuffixMatchNode()
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+    function alterProtobufResponse(dq, protobuf)
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then
+          requestor:truncate(24)
         else
-          requestor:truncate(56)                               
+          requestor:truncate(56)
         end
-        protobuf:setRequestor(requestor)                       
+        protobuf:setRequestor(requestor)
 
-       local tableTags = {}    
+       local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
-                                 
-       protobuf:setTagArray(tableTags)                          
 
-       protobuf:setTag('TestLabel3,TestData3')                 
+       protobuf:setTagArray(tableTags)
+
+       protobuf:setTag('TestLabel3,TestData3')
+
+       protobuf:setTag("Response,456")
 
-       protobuf:setTag("Response,456")                         
-      else                                                     
-       local tableTags = {}                                    -- called by testProtobuf()                                   
+      else
+       local tableTags = {}                                    -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
-       protobuf:setTagArray(tableTags)                         
+       protobuf:setTagArray(tableTags)
+
+       protobuf:setTag('TestLabel3,TestData3')
 
-       protobuf:setTag('TestLabel3,TestData3')                 
+       protobuf:setTag("Response,456")
 
-       protobuf:setTag("Response,456")                         
       end
     end
 
-    function alterProtobufQuery(dq, protobuf)                  
-      if luasmn:check(dq.qname) then                           
+    function alterProtobufQuery(dq, protobuf)
+
+      if luasmn:check(dq.qname) then
         requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then                             
-          requestor:truncate(24)                               
+        if requestor:isIPv4() then
+          requestor:truncate(24)
         else
-          requestor:truncate(56)                               
+          requestor:truncate(56)
         end
-        protobuf:setRequestor(requestor)                       
+        protobuf:setRequestor(requestor)
 
-       local tableTags = {}                                    
+       local tableTags = {}
        tableTags = dq:getTagArray()                            -- get table from DNSQuery
 
        local tablePB = {}
@@ -79,28 +82,32 @@ class TestProtobuf(DNSDistTest):
        protobuf:setProtobufResponseType()                      -- set protobuf to look like a response and not a query, with 0 default time
 
        blobData = '\127' .. '\000' .. '\000' .. '\001'         -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
-       protobuf:addResponseRR(strReqName, 1, 1, 123, blobData)  -- add a RR to the protobuf
-                                                               
+
+       protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
+
+       protobuf:setBytes(65)                                   -- set the size of the query to confirm in checkProtobufBase
+
       else
 
        local tableTags = {}                                    -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
 
-       protobuf:setTagArray(tableTags)                         
-       protobuf:setTag('TestLabel3,TestData3')                 
-       protobuf:setTag("Query,123")                            
+       protobuf:setTagArray(tableTags)
+       protobuf:setTag('TestLabel3,TestData3')
+       protobuf:setTag("Query,123")
+
       end
     end
 
     function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
-       local tt = {}                   
-        tt["TestLabel1"] = "TestData1"                  
-        tt["TestLabel2"] = "TestData2"                  
+       local tt = {}
+        tt["TestLabel1"] = "TestData1"
+        tt["TestLabel2"] = "TestData2"
 
-       dq:setTagArray(tt)                                      
+       dq:setTagArray(tt)
 
-       dq:setTag("TestLabel3","TestData3")                     
+       dq:setTag("TestLabel3","TestData3")
        return DNSAction.None, ""                               -- continue to the next rule
     end
 
@@ -166,21 +173,22 @@ class TestProtobuf(DNSDistTest):
         msg.ParseFromString(data)
         return msg
 
-    def checkProtobufBase(self, msg, protocol, query, initiator):
+    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
         self.assertTrue(msg)
         self.assertTrue(msg.HasField('timeSec'))
         self.assertTrue(msg.HasField('socketFamily'))
         self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
         self.assertTrue(msg.HasField('from'))
         fromvalue = getattr(msg, 'from')
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)              
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
         self.assertTrue(msg.HasField('socketProtocol'))
         self.assertEquals(msg.socketProtocol, protocol)
         self.assertTrue(msg.HasField('messageId'))
         self.assertTrue(msg.HasField('id'))
-        self.assertEquals(msg.id, query.id)                                                    
+        self.assertEquals(msg.id, query.id)
         self.assertTrue(msg.HasField('inBytes'))
-        self.assertEquals(msg.inBytes, len(query.to_wire()))           
+       if normalQueryResponse:
+               self.assertEquals(msg.inBytes, len(query.to_wire()))        # compare inBytes with length of query/response
         # dnsdist doesn't set the existing EDNS Subnet for now,
         # although it might be set from Lua
         # self.assertTrue(msg.HasField('originalRequestorSubnet'))
@@ -189,12 +197,7 @@ class TestProtobuf(DNSDistTest):
 
 
     def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
-
-       if initiator == '127.0.0.1':
-               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)           # testProtobuf()
-       else:
-               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)        # testLuaProtobuf()
-
+       self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)           # testProtobuf()
         self.checkProtobufBase(msg, protocol, query, initiator)
         # dnsdist doesn't fill the responder field for responses
         # because it doesn't keep the information around.
@@ -208,18 +211,25 @@ class TestProtobuf(DNSDistTest):
         self.assertTrue(msg.question.HasField('qName'))
         self.assertEquals(msg.question.qName, qname)
 
-
-
        testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
        listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
-
        self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query")               # exclusive or of lists should be empty
 
+    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.checkProtobufBase(msg, protocol, response, initiator, False)      # skip comparing inBytes with length of response as was query not response originally
+        self.assertTrue(msg.HasField('response'))
+        self.assertTrue(msg.response.HasField('queryTimeSec'))
+
+       testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
+       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
+       self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response")           # exclusive or of lists should be empty
+
     def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
-        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)       
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
         self.checkProtobufBase(msg, protocol, response, initiator)
-        self.assertTrue(msg.HasField('response'))                      
-        self.assertTrue(msg.response.HasField('queryTimeSec')) 
+        self.assertTrue(msg.HasField('response'))
+        self.assertTrue(msg.response.HasField('queryTimeSec'))
 
        testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
        listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
@@ -240,7 +250,7 @@ class TestProtobuf(DNSDistTest):
         """
         Protobuf: Send data to a protobuf server
         """
-        name = 'query.protobuf.tests.powerdns.com.'            
+        name = 'query.protobuf.tests.powerdns.com.'
 
 
         target = 'target.protobuf.tests.powerdns.com.'
@@ -274,12 +284,12 @@ class TestProtobuf(DNSDistTest):
         # check the protobuf message corresponding to the UDP query
         msg = self.getFirstProtobufMessage()
 
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)         
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
 
         # check the protobuf message corresponding to the UDP response
         msg = self.getFirstProtobufMessage()
         self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)     # check UDP response
-        self.assertEquals(len(msg.response.rrs), 2)                    
+        self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
         self.assertEquals(rr.rdata, target)
@@ -299,6 +309,7 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the TCP query
         msg = self.getFirstProtobufMessage()
+
         self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
 
         # check the protobuf message corresponding to the TCP response
@@ -330,8 +341,8 @@ class TestProtobuf(DNSDistTest):
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
 
-        self.assertTrue(receivedQuery)                 
-        self.assertTrue(receivedResponse)              
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
@@ -342,7 +353,8 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the UDP query
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+       self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')         # check UDP Response
 
         # check the protobuf message corresponding to the UDP response
         msg = self.getFirstProtobufMessage()
@@ -364,7 +376,7 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the TCP query
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+       self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')         # check TCP Response
 
         # check the protobuf message corresponding to the TCP response
         msg = self.getFirstProtobufMessage()