]> granicus.if.org Git - pdns/commitdiff
addressed fixes requested by Remi July 3rd
authorSeth Ornstein <sornstein@globalcyberalliance.org>
Thu, 6 Jul 2017 15:06:05 +0000 (11:06 -0400)
committerNick Saika <nicksaika@gmail.com>
Wed, 2 Aug 2017 15:06:43 +0000 (11:06 -0400)
pdns/dnsdist-lua.cc
pdns/dnsdist-lua2.cc
pdns/dnsdist.hh
pdns/protobuf.cc
pdns/protobuf.hh
regression-tests.dnsdist/test_Protobuf.py
regression-tests.dnsdist/test_ProtobufTag.py [deleted file]

index e1ccacb2d726fdd407c586829ca7b6242f27f481..ee7c438fdad0baae666634b35af0f8886eb5e20c 100644 (file)
@@ -1586,54 +1586,47 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       }
     });
 
-    g_lua.registerFunction<void(DNSQuestion::*)(std::string, std::string)>("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) {
+  g_lua.registerFunction<void(DNSQuestion::*)(std::string, std::string)>("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) {
 
-      if(dq.qTag == NULL)
-        {
-         dq.qTag = std::shared_ptr<QTag>(new QTag);
-        }
+      if(dq.qTag == nullptr) {
+        dq.qTag = std::make_shared<QTag>();
+      }
       dq.qTag->add(strLabel, strValue);
 
     });
 
-     g_lua.registerFunction<void(DNSQuestion::*)(vector<pair<string, string>>)>("setTagArray", [](DNSQuestion& dq, const vector<pair<string, string>>&tags) {
+  g_lua.registerFunction<void(DNSQuestion::*)(vector<pair<string, string>>)>("setTagArray", [](DNSQuestion& dq, const vector<pair<string, string>>&tags) {
 
-      if(dq.qTag == NULL)
-        {
-         dq.qTag = std::shared_ptr<QTag>(new QTag);
-        }
+      if(dq.qTag == nullptr) {
+        dq.qTag = std::make_shared<QTag>();
+      }
 
-      for (const auto& tag : tags)
-        {
-           dq.qTag->add(tag.first, tag.second);
-        }
+      for (const auto& tag : tags) {
+        dq.qTag->add(tag.first, tag.second);
+      }
 
      });
 
 
-     g_lua.registerFunction<string(DNSQuestion::*)(std::string)>("getTagMatch", [](const DNSQuestion& dq, const std::string& strLabel) {
+  g_lua.registerFunction<string(DNSQuestion::*)(std::string)>("getTag", [](const DNSQuestion& dq, const std::string& strLabel) {
 
     std::string strValue;
-    if(dq.qTag != NULL)
-      {
-        strValue = dq.qTag->getMatch(strLabel);
-      }
+    if(dq.qTag != nullptr) {
+      strValue = dq.qTag->getMatch(strLabel);
+    }
     return strValue;
 
      });
 
 
-     g_lua.registerFunction<std::unordered_map<string, string>(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) {
+  g_lua.registerFunction<std::unordered_map<string, string>(DNSQuestion::*)(void)>("getTagArray", [](const DNSQuestion& dq) {
 
-    if(dq.qTag != NULL)
-      {
-        return dq.qTag->tagData;
-      }
-    else
-      {
-       std::unordered_map<string, string> XX;
-       return(XX);
-      }
+    if(dq.qTag != nullptr) {
+      return dq.qTag->tagData;
+    } else {
+      std::unordered_map<string, string> XX;
+      return XX;
+    }
       });
 
 
index 6d8e9af895a5152c2a3c07545afdf1248d8394cb..f31046c628702024a6c9b49d1f58dff5d6b6b85b 100644 (file)
@@ -831,47 +831,27 @@ void moreLua(bool client)
 #endif
       });
 
-     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(std::string)>("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) {
-
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(std::string)>("setTag", [](DNSDistProtoBufMessage& message, const std::string& strValue) {
       message.addTag(strValue);
      });
 
-     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(vector<pair<int, string>>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector<pair<int, string>>&tags) {
-
-
-      for (const auto& tag : tags)
-        {
-          message.addTag(tag.second);
-        }
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(vector<pair<int, string>>)>("setTagArray", [](DNSDistProtoBufMessage& message, const vector<pair<int, string>>&tags) {
+      for (const auto& tag : tags) {
+        message.addTag(tag.second);
+      }
      });
 
-     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint> uSec)>("setProtobufResponseType",
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(boost::optional <time_t> sec, boost::optional <uint> uSec)>("setProtobufResponseType",
                                         [](DNSDistProtoBufMessage& message, boost::optional <time_t> sec, boost::optional <uint> uSec) {
-
         message.setType(DNSProtoBufMessage::Response);
-
         message.setQueryTime(sec?*sec:0, uSec?*uSec:0);
-
      });
 
-     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string&, uint uType, uint uClass, uint uTTL, vector<pair<int, int>> )>("setProtobufResponseRR", [](DNSDistProtoBufMessage& message,
-                                                            const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const vector<pair<int, int>>& blobData) {
-
-        size_t blobSize = blobData.size();
-
-        unique_ptr<uint8_t[]> ptrBlob (new uint8_t(blobSize));
-
-        int jj=0;
-        for (const auto& blob : blobData)
-           {
-            ptrBlob[jj++] = blob.second;
-           }
-
-        message.addRR(strQueryName, uType, uClass, uTTL, ptrBlob.get(), blobSize);
-
+    g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const std::string& strQueryName, uint uType, uint uClass, uint uTTL, const std::string& strBlob)>("addResponseRR", [](DNSDistProtoBufMessage& message,
+                                                            const std::string& strQueryName, uint16_t uType, uint uClass, uint32_t uTTL, const std::string& strBlob) {
+        message.addRR(strQueryName, uType, uClass, uTTL, strBlob);
      });
 
-
     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const Netmask&)>("setEDNSSubnet", [](DNSDistProtoBufMessage& message, const Netmask& subnet) { message.setEDNSSubnet(subnet); });
     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(const DNSName&, uint16_t, uint16_t)>("setQuestion", [](DNSDistProtoBufMessage& message, const DNSName& qname, uint16_t qtype, uint16_t qclass) { message.setQuestion(qname, qtype, qclass); });
     g_lua.registerFunction<void(DNSDistProtoBufMessage::*)(size_t)>("setBytes", [](DNSDistProtoBufMessage& message, size_t bytes) { message.setBytes(bytes); });
index 85832ebbe6d4ce19a32da07d0aaff5212ced6370..bb427a8baa4a1bcd1847e32e5750216919d00c0f 100644 (file)
@@ -68,26 +68,22 @@ QTag()
 {
 }
 
-bool add(std::string strLabel, std::string strValue)
+void add(std::string strLabel, std::string strValue)
 {
-bool bStatus = true;
 
-   tagData.insert( {strLabel, strValue});
-  return(bStatus);
+  tagData.insert( {strLabel, strValue});
+  return;
 }
 
 std::string getMatch(const std::string& strLabel)  const
 {
 
-    std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
-    if(got == tagData.end())
-      {
-       return("");
-      }
-    else
-      {
-       return(got->second);
-      }
+  std::unordered_map<std::string, std::string>::const_iterator got =tagData.find (strLabel);
+  if(got == tagData.end()) {
+    return "";
+  } else {
+    return got->second;
+  }
 }
 
 std::string getEntry(size_t iEntry) const
@@ -96,25 +92,23 @@ std::string strEntry;
 size_t iCounter = 0;
 
 
-   for (const auto& itr : tagData)
-      {
-        iCounter++;
-        if(iCounter == iEntry)
-          {
-           strEntry = itr.first;
-           strEntry += strSep;
-           strEntry += itr.second;
-           break;
-          }
-      }
+  for (const auto& itr : tagData) {
+    iCounter++;
+    if(iCounter == iEntry) {
+      strEntry = itr.first;
+      strEntry += strSep;
+      strEntry += itr.second;
+      break;
+    }
+  }
 
-   return(strEntry);
+  return strEntry;
 
 }
 
 size_t count() const
 {
-   return(tagData.size());
+   return tagData.size();
 }
 
 std::string dumpString() const
@@ -122,15 +116,13 @@ std::string dumpString() const
 std::string strRet;
 
 
-   for (const auto& itr : tagData)
-      {
-        strRet += itr.first;
-        strRet += strSep;
-        strRet += itr.second;
-        strRet += "\n";
-       }
-    return(strRet);
-
+  for (const auto& itr : tagData) {
+    strRet += itr.first;
+    strRet += strSep;
+    strRet += itr.second;
+    strRet += "\n";
+  }
+  return strRet;
 
 }
 
index 47f02cb643e5200dce4b601c262791299d16df50..66c0c8b22e81fa6b2f8ec3981c5ab0e89513eb11 100644 (file)
@@ -110,7 +110,7 @@ void DNSProtoBufMessage::addTag(const std::string& strValue)
 #endif /* HAVE_PROTOBUF */
 }
 
-void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const uint8_t *ptrBlob, size_t uBlobLen)
+void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint32_t uClass, uint32_t uTTL, const std::string& strBlob)
 {
 #ifdef HAVE_PROTOBUF
 
@@ -120,12 +120,11 @@ void DNSProtoBufMessage::addRR(const std::string& strName, uint32_t uType, uint3
 
   PBDNSMessage_DNSResponse_DNSRR* rr = response->add_rrs();
       if (rr) {
-        string blob;
         rr->set_name(strName.c_str());
         rr->set_type(uType);
         rr->set_class_(uClass);
         rr->set_ttl(uTTL);
-        rr->set_rdata(ptrBlob, uBlobLen);
+        rr->set_rdata((const uint8_t *) strBlob.c_str(), strBlob.size());
       }
 
 #endif /* HAVE_PROTOBUF */
index b31a5a45ba3194bd963fb5aca8c59b5f47e4e151..9decfdc24110928d97edffbcdb2454a130f90dd1 100644 (file)
@@ -71,7 +71,7 @@ public:
   void setRequestorId(const std::string& requestorId);
   std::string toDebugString() const;
   void addTag(const std::string& strValue);
-  void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const uint8_t *ptrBlob, size_t uBlobLen);
+  void addRR(const std::string& strName, uint32_t utype, uint32_t uClass, uint32_t uTTl, const std::string& strBlob);
 
 #ifdef HAVE_PROTOBUF
   DNSProtoBufMessage(DNSProtoBufMessage::DNSProtoBufMessageType type, const boost::uuids::uuid& uuid, const ComboAddress* requestor, const ComboAddress* responder, const DNSName& domain, int qtype, uint16_t qclass, uint16_t qid, bool isTCP, size_t bytes);
index 6bbd4cfb41976ccab12b86e046319f7aca87905a..ba3ffdf88c4a526ba995bdf3d501de748e50a23b 100644 (file)
@@ -10,32 +10,110 @@ from dnsdisttests import DNSDistTest
 import dns
 import dnsmessage_pb2
 
-class TestProtobuf(DNSDistTest):
 
+class TestProtobuf(DNSDistTest):
     _protobufServerPort = 4242
     _protobufQueue = Queue.Queue()
     _protobufCounter = 0
     _config_params = ['_testServerPort', '_protobufServerPort']
     _config_template = """
-    luasmn = newSuffixMatchNode()
-    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
-
-    function alterProtobuf(dq, protobuf)
-      if luasmn:check(dq.qname) then
-        requestor = newCA(dq.remoteaddr:toString())
-        if requestor:isIPv4() then
-          requestor:truncate(24)
+    luasmn = newSuffixMatchNode()                              
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) 
+
+    function alterProtobufResponse(dq, protobuf)               
+      if luasmn:check(dq.qname) then                           
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()  
+        if requestor:isIPv4() then                             
+          requestor:truncate(24)                               
         else
-          requestor:truncate(56)
+          requestor:truncate(56)                               
         end
-        protobuf:setRequestor(requestor)
+        protobuf:setRequestor(requestor)                       
+
+       local tableTags = {}    
+       table.insert(tableTags, "TestLabel1,TestData1")
+       table.insert(tableTags, "TestLabel2,TestData2")
+                                 
+       protobuf:setTagArray(tableTags)                          
+
+       protobuf:setTag('TestLabel3,TestData3')                 
+
+       protobuf:setTag("Response,456")                         
+      else                                                     
+       local tableTags = {}                                    -- called by testProtobuf()                                   
+       table.insert(tableTags, "TestLabel1,TestData1")
+       table.insert(tableTags, "TestLabel2,TestData2")
+       protobuf:setTagArray(tableTags)                         
+
+       protobuf:setTag('TestLabel3,TestData3')                 
+
+       protobuf:setTag("Response,456")                         
+      end
+    end
+
+    function alterProtobufQuery(dq, protobuf)                  
+      if luasmn:check(dq.qname) then                           
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then                             
+          requestor:truncate(24)                               
+        else
+          requestor:truncate(56)                               
+        end
+        protobuf:setRequestor(requestor)                       
+
+       local tableTags = {}                                    
+       tableTags = dq:getTagArray()                            -- get table from DNSQuery
+
+       local tablePB = {}
+               for k, v in pairs( tableTags) do
+               table.insert(tablePB, k .. "," .. v)
+       end
+
+       protobuf:setTagArray(tablePB)                           -- store table in protobuf
+       protobuf:setTag("Query,123")                            -- add another tag entry in protobuf
+
+       protobuf:setResponseCode(dnsdist.NXDOMAIN)              -- set protobuf response code to be NXDOMAIN
+
+       local strReqName = dq.qname:toString()                  -- get request dns name
+
+       protobuf:setProtobufResponseType()                      -- set protobuf to look like a response and not a query, with 0 default time
+
+       blobData = '\127' .. '\000' .. '\000' .. '\001'         -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
+       protobuf:addResponseRR(strReqName, 1, 1, 123, blobData)  -- add a RR to the protobuf
+                                                               
+      else
+
+       local tableTags = {}                                    -- called by testProtobuf()
+       table.insert(tableTags, "TestLabel1,TestData1")
+       table.insert(tableTags, "TestLabel2,TestData2")
+
+       protobuf:setTagArray(tableTags)                         
+       protobuf:setTag('TestLabel3,TestData3')                 
+       protobuf:setTag("Query,123")                            
       end
     end
 
+    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
+       local tt = {}                   
+        tt["TestLabel1"] = "TestData1"                  
+        tt["TestLabel2"] = "TestData2"                  
+
+       dq:setTagArray(tt)                                      
+
+       dq:setTag("TestLabel3","TestData3")                     
+       return DNSAction.None, ""                               -- continue to the next rule
+    end
+
+
     newServer{address="127.0.0.1:%s", useClientSubnet=true}
     rl = newRemoteLogger('127.0.0.1:%s')
-    addAction(AllRule(), RemoteLogAction(rl, alterProtobuf))
-    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf, true))
+
+    addLuaAction(AllRule(), alterLuaFirst)                                                     -- Add tags to DNSQuery first
+
+    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery))                              -- Send protobuf message before lookup
+
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true))     -- Send protobuf message after lookup
+
     """
 
     @classmethod
@@ -71,6 +149,7 @@ class TestProtobuf(DNSDistTest):
         cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
         cls._UDPResponder.setDaemon(True)
         cls._UDPResponder.start()
+
         cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
@@ -94,22 +173,28 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
         self.assertTrue(msg.HasField('from'))
         fromvalue = getattr(msg, 'from')
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)              
         self.assertTrue(msg.HasField('socketProtocol'))
         self.assertEquals(msg.socketProtocol, protocol)
         self.assertTrue(msg.HasField('messageId'))
         self.assertTrue(msg.HasField('id'))
-        self.assertEquals(msg.id, query.id)
+        self.assertEquals(msg.id, query.id)                                                    
         self.assertTrue(msg.HasField('inBytes'))
-        self.assertEquals(msg.inBytes, len(query.to_wire()))
+        self.assertEquals(msg.inBytes, len(query.to_wire()))           
         # dnsdist doesn't set the existing EDNS Subnet for now,
         # although it might be set from Lua
         # self.assertTrue(msg.HasField('originalRequestorSubnet'))
         # self.assertEquals(len(msg.originalRequestorSubnet), 4)
         # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
 
+
     def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
-        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
+
+       if initiator == '127.0.0.1':
+               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)           # testProtobuf()
+       else:
+               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)        # testLuaProtobuf()
+
         self.checkProtobufBase(msg, protocol, query, initiator)
         # dnsdist doesn't fill the responder field for responses
         # because it doesn't keep the information around.
@@ -123,11 +208,22 @@ class TestProtobuf(DNSDistTest):
         self.assertTrue(msg.question.HasField('qName'))
         self.assertEquals(msg.question.qName, qname)
 
+
+
+       testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
+       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
+
+       self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query")               # exclusive or of lists should be empty
+
     def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
-        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)       
         self.checkProtobufBase(msg, protocol, response, initiator)
-        self.assertTrue(msg.HasField('response'))
-        self.assertTrue(msg.response.HasField('queryTimeSec'))
+        self.assertTrue(msg.HasField('response'))                      
+        self.assertTrue(msg.response.HasField('queryTimeSec')) 
+
+       testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
+       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
+       self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response")           # exclusive or of lists should be empty
 
     def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
         self.assertTrue(record.HasField('class'))
@@ -144,16 +240,20 @@ class TestProtobuf(DNSDistTest):
         """
         Protobuf: Send data to a protobuf server
         """
-        name = 'query.protobuf.tests.powerdns.com.'
+        name = 'query.protobuf.tests.powerdns.com.'            
+
+
         target = 'target.protobuf.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
+
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
                                     dns.rdatatype.CNAME,
                                     target)
         response.answer.append(rrset)
+
         rrset = dns.rrset.from_text(target,
                                     3600,
                                     dns.rdataclass.IN,
@@ -173,12 +273,13 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the UDP query
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)         
 
         # check the protobuf message corresponding to the UDP response
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
-        self.assertEquals(len(msg.response.rrs), 2)
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)     # check UDP response
+        self.assertEquals(len(msg.response.rrs), 2)                    
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
         self.assertEquals(rr.rdata, target)
@@ -202,7 +303,7 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the TCP response
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)     # check TCP response
         self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
@@ -212,6 +313,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
 
     def testLuaProtobuf(self):
+
         """
         Protobuf: Check that the Lua callback rewrote the initiator
         """
@@ -225,13 +327,16 @@ class TestProtobuf(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
+
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
+
+        self.assertTrue(receivedQuery)                 
+        self.assertTrue(receivedResponse)              
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
 
+
         # let the protobuf messages the time to get there
         time.sleep(1)
 
@@ -241,7 +346,7 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the UDP response
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')                # check UDP Response
         self.assertEquals(len(msg.response.rrs), 1)
         for rr in msg.response.rrs:
             self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
@@ -263,7 +368,7 @@ class TestProtobuf(DNSDistTest):
 
         # check the protobuf message corresponding to the TCP response
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')                # check TCP response
         self.assertEquals(len(msg.response.rrs), 1)
         for rr in msg.response.rrs:
             self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
diff --git a/regression-tests.dnsdist/test_ProtobufTag.py b/regression-tests.dnsdist/test_ProtobufTag.py
deleted file mode 100644 (file)
index bb743a6..0000000
+++ /dev/null
@@ -1,377 +0,0 @@
-#!/usr/bin/env python
-import Queue
-import threading
-import socket
-import struct
-import sys
-import time
-from dnsdisttests import DNSDistTest
-
-import dns
-import dnsmessage_pb2
-
-
-class TestProtobuf(DNSDistTest):
-    _protobufServerPort = 4242
-    _protobufQueue = Queue.Queue()
-    _protobufCounter = 0
-    _config_params = ['_testServerPort', '_protobufServerPort']
-    _config_template = """
-    luasmn = newSuffixMatchNode()                              
-    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) 
-
-    function alterProtobufResponse(dq, protobuf)               
-      if luasmn:check(dq.qname) then                           
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()  
-        if requestor:isIPv4() then                             
-          requestor:truncate(24)                               
-        else
-          requestor:truncate(56)                               
-        end
-        protobuf:setRequestor(requestor)                       
-
-       local tableTags = {}    
-       table.insert(tableTags, "TestLabel1,TestData1")
-       table.insert(tableTags, "TestLabel2,TestData2")
-                                 
- --       tableTags["TestLabel2"] = "TestData2"                    
---        tableTags["TestLabel1"] = "TestData1"                    
-       protobuf:setTagArray(tableTags)                         -- setTagArray 
-
-       protobuf:setTag('TestLabel3,TestData3')                 -- setTag
-
-       protobuf:setTag("Response,456")                         -- setTag
-      else                                                     
-       local tableTags = {}                                    -- called by testProtobuf()                                   
-       table.insert(tableTags, "TestLabel1,TestData1")
-       table.insert(tableTags, "TestLabel2,TestData2")
-       protobuf:setTagArray(tableTags)                         -- setTagArray
-
-       protobuf:setTag('TestLabel3,TestData3')                 -- setTag
-
-       protobuf:setTag("Response,456")                         -- setTag
-      end
-    end
-
-    function alterProtobufQuery(dq, protobuf)                  
-      if luasmn:check(dq.qname) then                           
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then                             
-          requestor:truncate(24)                               
-        else
-          requestor:truncate(56)                               
-        end
-        protobuf:setRequestor(requestor)                       
-
-       local tableTags = {}                                    -- declare table
-       tableTags = dq:getTagArray()                            -- get table from DNSQuery
-
-       local tablePB = {}
-               for k, v in pairs( tableTags) do
-               table.insert(tablePB, k .. "," .. v)
-       end
-
-       protobuf:setTagArray(tablePB)                           -- store table in protobuf
-       protobuf:setTag("Query,123")                            -- add another tag entry in protobuf
-
-       protobuf:setResponseCode(dnsdist.NXDOMAIN)              -- set protobuf response code to be NXDOMAIN
-
-       local strReqName = dq.qname:toString()                  -- get request dns name
-
-       protobuf:setProtobufResponseType()                      -- set protobuf to look like a response and not a query, with 0 default time
-
-       blobData={0x7F, 0x00, 0x00, 0x01}                       -- 127.0.0.1
-       protobuf:setProtobufResponseRR(strReqName, 1, 1, 123, blobData)    -- set protobuf to have a RR 
-                                                               
-      else
-
-       local tableTags = {}                                    -- called by testProtobuf()
-       table.insert(tableTags, "TestLabel1,TestData1")
-       table.insert(tableTags, "TestLabel2,TestData2")
-
-       protobuf:setTagArray(tableTags)                         -- setTagArray
-       protobuf:setTag('TestLabel3,TestData3')                 -- setTag
-       protobuf:setTag("Query,123")                            -- setTag
-      end
-    end
-
-    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
-       local tt = {}                   
-        tt["TestLabel1"] = "TestData1"                  
-        tt["TestLabel2"] = "TestData2"                  
-
-       dq:setTagArray(tt)                                      --setTagArray
-
-       dq:setTag("TestLabel3","TestData3")                     -- setTag
-       return DNSAction.None, ""                               -- continue to the next rule
-    end
-
-
-    newServer{address="127.0.0.1:%s", useClientSubnet=true}
-    rl = newRemoteLogger('127.0.0.1:%s')
-
-    addLuaAction(AllRule(), alterLuaFirst)                                                     -- Add tags to DNSQuery first
-
-    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery))                              -- Send protobuf message before lookup
-
-    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true))     -- Send protobuf message after lookup
-
-    """
-
-    @classmethod
-    def ProtobufListener(cls, port):
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-        try:
-            sock.bind(("127.0.0.1", port))
-        except socket.error as e:
-            print("Error binding in the protbuf listener: %s" % str(e))
-            sys.exit(1)
-
-        sock.listen(100)
-        while True:
-            (conn, _) = sock.accept()
-            data = None
-            while True:
-                data = conn.recv(2)
-                if not data:
-                    break
-                (datalen,) = struct.unpack("!H", data)
-                data = conn.recv(datalen)
-                if not data:
-                    break
-
-                cls._protobufQueue.put(data, True, timeout=2.0)
-
-            conn.close()
-        sock.close()
-
-    @classmethod
-    def startResponders(cls):
-        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
-        cls._UDPResponder.setDaemon(True)
-        cls._UDPResponder.start()
-
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
-        cls._TCPResponder.setDaemon(True)
-        cls._TCPResponder.start()
-
-        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
-        cls._protobufListener.setDaemon(True)
-        cls._protobufListener.start()
-
-    def getFirstProtobufMessage(self):
-        self.assertFalse(self._protobufQueue.empty())
-        data = self._protobufQueue.get(False)
-        self.assertTrue(data)
-        msg = dnsmessage_pb2.PBDNSMessage()
-        msg.ParseFromString(data)
-        return msg
-
-    def checkProtobufBase(self, msg, protocol, query, initiator):
-        self.assertTrue(msg)
-        self.assertTrue(msg.HasField('timeSec'))
-        self.assertTrue(msg.HasField('socketFamily'))
-        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
-        self.assertTrue(msg.HasField('from'))
-        fromvalue = getattr(msg, 'from')
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)              
-        self.assertTrue(msg.HasField('socketProtocol'))
-        self.assertEquals(msg.socketProtocol, protocol)
-        self.assertTrue(msg.HasField('messageId'))
-        self.assertTrue(msg.HasField('id'))
-        self.assertEquals(msg.id, query.id)                                                    
-        self.assertTrue(msg.HasField('inBytes'))
-        self.assertEquals(msg.inBytes, len(query.to_wire()))           
-        # dnsdist doesn't set the existing EDNS Subnet for now,
-        # although it might be set from Lua
-        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
-        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
-        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
-
-
-    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
-
-       if initiator == '127.0.0.1':
-               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)           # testProtobuf()
-       else:
-               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)        # testLuaProtobuf()
-
-        self.checkProtobufBase(msg, protocol, query, initiator)
-        # dnsdist doesn't fill the responder field for responses
-        # because it doesn't keep the information around.
-        self.assertTrue(msg.HasField('to'))
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
-        self.assertTrue(msg.HasField('question'))
-        self.assertTrue(msg.question.HasField('qClass'))
-        self.assertEquals(msg.question.qClass, qclass)
-        self.assertTrue(msg.question.HasField('qType'))
-        self.assertEquals(msg.question.qClass, qtype)
-        self.assertTrue(msg.question.HasField('qName'))
-        self.assertEquals(msg.question.qName, qname)
-
-
-
-       testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
-       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
-
-       self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query")               # exclusive or of lists should be empty
-
-    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
-        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)       
-        self.checkProtobufBase(msg, protocol, response, initiator)
-        self.assertTrue(msg.HasField('response'))                      
-        self.assertTrue(msg.response.HasField('queryTimeSec')) 
-
-       testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
-       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
-       self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response")           # exclusive or of lists should be empty
-
-    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
-        self.assertTrue(record.HasField('class'))
-        self.assertEquals(getattr(record, 'class'), rclass)
-        self.assertTrue(record.HasField('type'))
-        self.assertEquals(record.type, rtype)
-        self.assertTrue(record.HasField('name'))
-        self.assertEquals(record.name, rname)
-        self.assertTrue(record.HasField('ttl'))
-        self.assertEquals(record.ttl, rttl)
-        self.assertTrue(record.HasField('rdata'))
-
-    def testProtobuf(self):
-        """
-        Protobuf: Send data to a protobuf server
-        """
-        name = 'query.protobuf.tests.powerdns.com.'            
-
-
-        target = 'target.protobuf.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'A', 'IN')
-        response = dns.message.make_response(query)
-
-        rrset = dns.rrset.from_text(name,
-                                    3600,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.CNAME,
-                                    target)
-        response.answer.append(rrset)
-
-        rrset = dns.rrset.from_text(target,
-                                    3600,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.A,
-                                    '127.0.0.1')
-        response.answer.append(rrset)
-
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        # let the protobuf messages the time to get there
-        time.sleep(1)
-
-        # check the protobuf message corresponding to the UDP query
-        msg = self.getFirstProtobufMessage()
-
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)         
-
-        # check the protobuf message corresponding to the UDP response
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)     # check UDP response
-        self.assertEquals(len(msg.response.rrs), 2)                    
-        rr = msg.response.rrs[0]
-        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
-        rr = msg.response.rrs[1]
-        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        # let the protobuf messages the time to get there
-        time.sleep(1)
-
-        # check the protobuf message corresponding to the TCP query
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
-
-        # check the protobuf message corresponding to the TCP response
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)     # check TCP response
-        self.assertEquals(len(msg.response.rrs), 2)
-        rr = msg.response.rrs[0]
-        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
-        rr = msg.response.rrs[1]
-        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
-        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
-    def testLuaProtobuf(self):
-
-        """
-        Protobuf: Check that the Lua callback rewrote the initiator
-        """
-        name = 'lua.protobuf.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'A', 'IN')
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    3600,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.A,
-                                    '127.0.0.1')
-        response.answer.append(rrset)
-
-
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-
-        self.assertTrue(receivedQuery)                 
-        self.assertTrue(receivedResponse)              
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-
-        # let the protobuf messages the time to get there
-        time.sleep(1)
-
-        # check the protobuf message corresponding to the UDP query
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
-
-        # check the protobuf message corresponding to the UDP response
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')                # check UDP Response
-        self.assertEquals(len(msg.response.rrs), 1)
-        for rr in msg.response.rrs:
-            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
-            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        # let the protobuf messages the time to get there
-        time.sleep(1)
-
-        # check the protobuf message corresponding to the TCP query
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
-
-        # check the protobuf message corresponding to the TCP response
-        msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')                # check TCP response
-        self.assertEquals(len(msg.response.rrs), 1)
-        for rr in msg.response.rrs:
-            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
-            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')