]> granicus.if.org Git - pdns/commitdiff
dnsdist: Replace the Lua params with a DNSQuestion `dq` object
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 14 Jan 2016 11:57:33 +0000 (12:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 14 Jan 2016 11:57:33 +0000 (12:57 +0100)
In order to:
1. Be able to add functions/member without breaking the API
2. Being as compatible as possible with the PowerDNS Lua API

To limit the parsing/copy to a minimum, this DNSQuestion differs
from the PowerDNS one. Most Lua members are properly wrapped,
but it currently lacks some advanced functions like `getRecords()`
or `setRecords()`, that we might add later.
In addition to the existing `tostring()`, this commit adds
`toString()` ones to match the PowerDNS syntax.

LuaWrapper is supposed to support read-only members, where you
only define the getter and no setter, but I can't find the right
syntax for that to work, so for now the setter are present for
read-only members, and just do nothing.

pdns/README-dnsdist.md
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistconf.lua
pdns/dnsrulactions.hh
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_EdnsClientSubnet.py

index e7c9711ce279479786b03e3da1b9addbd6915004..408b18532cabf8800d6732fad4edbce7cdc40f6b 100644 (file)
@@ -410,8 +410,8 @@ table of domain suffixes.  This is identical to how `addPoolRule()` selects.
 
 The function should look like this:
 ```
-function luarule(remote, qname, qtype, dh, len)
-        if(qtype==35) -- NAPTR
+function luarule(dq)
+        if(dq.qtype==35) -- NAPTR
         then
                 return DNSAction.Pool, "abuse" -- send to abuse pool
         else
@@ -570,7 +570,7 @@ for example:
 
 ```
 counter=0
-function luaroundrobin(servers, remote, qname, qtype, dh) 
+function luaroundrobin(servers, dq)
         counter=counter+1
         return servers[1+(counter % #servers)]
 end
@@ -589,12 +589,12 @@ To implement a split horizon, try:
 ```
 authServer=newServer({address="2001:888:2000:1d::2", pool="auth"})
 
-function splitSetup(servers, remote, qname, qtype, dh)
-        if(dh:getRD() == false)
+function splitSetup(servers, dq)
+        if(dq.dh:getRD() == false)
         then
-               return leastOutstanding.policy(getPoolServers("auth"), remote, qname, qtype, dh)
+               return leastOutstanding.policy(getPoolServers("auth"), dq)
         else
-               return leastOutstanding.policy(servers, remote, qname, qtype, dh)
+               return leastOutstanding.policy(servers, dq)
         end
 end
 
@@ -873,6 +873,16 @@ instantiate a server with additional parameters
      * `newDNSName(name)`: make a DNSName based on this .-terminated name
      * member `isPartOf(dnsname)`: is this dnsname part of that dnsname
      * member `tostring()`: return as a human friendly . terminated string
+   * DNSQuestion related:
+     * member `dh`: DNSHeader
+     * member `len`: the question length
+     * member `localaddr`: ComboAddress of the local bind this question was received on
+     * member `qname`: DNSName of this question
+     * member `qtype`: QType (as an unsigned integer) of this question
+     * member `remoteaddr`: ComboAddress of the remote client
+     * member `rcode`: RCode of this question
+     * member `size`: the total size of the buffer starting at `dh`
+     * member `tcp`: whether this question was received over a TCP socket
    * DNSHeader related
      * member `getRD()`: get recursion desired flag
      * member `setRD(bool)`: set recursion desired flag
@@ -909,10 +919,10 @@ All hooks
 ---------
 `dnsdist` can call Lua per packet if so configured, and will do so with the following hooks:
 
-  * `bool blockFilter(ComboAddress, DNSName, qtype, DNSHeader)`: if defined, called for every function. If this
+  * `bool blockFilter(ComboAddress, DNSQuestion)`: if defined, called for every function. If this
     returns true, the packet is dropped. If false is returned, `dnsdist` will check if the DNSHeader indicates
     the packet is now a query response. If so, `dnsdist` will answer the client directly with the modified packet.
-  * `server policy(candidates, ComboAddress, DNSName, qtype, DNSHeader)`: if configured with `setServerPolicyLua()` 
+  * `server policy(candidates, DNSQuestion)`: if configured with `setServerPolicyLua()`
     gets called for every packet. Candidates is a table of potential servers to pick from, ComboAddress is the 
     address of the requestor, DNSName and qtype describe name and type of query. DNSHeader meanwhile is available for 
     your inspection.
index 91df22529a6930b9f74a4df9c6a0c32376003d5c..12d442b0d31acbd5de99e7be48c99efb774ccdce 100644 (file)
@@ -16,19 +16,19 @@ static vector<std::function<void(void)>>* g_launchWork;
 class LuaAction : public DNSAction
 {
 public:
-  typedef std::function<std::tuple<int, string>(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t len, uint16_t bufferSize)> func_t;
+  typedef std::function<std::tuple<int, string>(DNSQuestion* dq)> func_t;
   LuaAction(LuaAction::func_t func) : d_func(func)
   {}
 
-  Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
-    auto ret = d_func(remote, qname, qtype, dh, len, bufferSize);
+    auto ret = d_func(dq);
     if(ruleresult)
       *ruleresult=std::get<1>(ret);
     return (Action)std::get<0>(ret);
   }
 
-  string toString() const 
+  string toString() const override
   {
     return "Lua script";
   }
@@ -658,12 +658,13 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       }
 
       int matches=0;
+      ComboAddress dummy("127.0.0.1");
       DTime dt;
       dt.set();
       for(int n=0; n < times; ++n) {
         const item& i = items[n % items.size()];
-        struct dnsheader* dh = (struct dnsheader*)&i.packet[0];
-        if(rule->matches(i.rem, i.qname, i.qtype, dh, i.packet.size()))
+        DNSQuestion dq(&i.qname, i.qtype, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false);
+        if(rule->matches(&dq))
           matches++;
       }
       double udiff=dt.udiff();
@@ -864,9 +865,12 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
 
   g_lua.registerFunction("tostring", &ComboAddress::toString);
   g_lua.registerFunction("tostringWithPort", &ComboAddress::toStringWithPort);
+  g_lua.registerFunction("toString", &ComboAddress::toString);
+  g_lua.registerFunction("toStringWithPort", &ComboAddress::toStringWithPort);
   g_lua.registerFunction<uint16_t(ComboAddress::*)()>("getPort", [](const ComboAddress& ca) { return ntohs(ca.sin4.sin_port); } );
   g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
   g_lua.registerFunction<string(DNSName::*)()>("tostring", [](const DNSName&dn ) { return dn.toString(); });
+  g_lua.registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
   g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
   g_lua.writeFunction("newSuffixMatchNode", []() { return SuffixMatchNode(); });
 
@@ -1178,6 +1182,19 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       }
     });
 
+  /* DNSQuestion bindings */
+  /* PowerDNS DNSQuestion compat */
+  g_lua.registerMember<const ComboAddress (DNSQuestion::*)>("localaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.local; }, [](DNSQuestion& dq, const ComboAddress newLocal) { (void) newLocal; });
+  g_lua.registerMember<const DNSName (DNSQuestion::*)>("qname", [](const DNSQuestion& dq) -> const DNSName { return *dq.qname; }, [](DNSQuestion& dq, const DNSName newName) { (void) newName; });
+  g_lua.registerMember<uint16_t (DNSQuestion::*)>("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; });
+  g_lua.registerMember<int (DNSQuestion::*)>("rcode", [](const DNSQuestion& dq) -> int { return dq.dh->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.dh->rcode = newRCode; });
+  g_lua.registerMember<const ComboAddress (DNSQuestion::*)>("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.remote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; });
+  /* DNSDist DNSQuestion */
+  g_lua.registerMember("dh", &DNSQuestion::dh);
+  g_lua.registerMember<uint16_t (DNSQuestion::*)>("len", [](const DNSQuestion& dq) -> uint16_t { return dq.len; }, [](DNSQuestion& dq, uint16_t newlen) { dq.len = newlen; });
+  g_lua.registerMember<size_t (DNSQuestion::*)>("size", [](const DNSQuestion& dq) -> size_t { return dq.size; }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; });
+  g_lua.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
+
   g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { g_maxTCPClientThreads = max; });
 
   g_lua.writeFunction("setECSSourcePrefixV4", [](uint16_t prefix) { g_ECSSourcePrefixV4=prefix; });
index e70599d6553deae81114361c82b5b05067fd0c09..2f467bdb69d25419ba9e13d20e2ada9f0d1750b2 100644 (file)
@@ -131,7 +131,7 @@ void* tcpClientThread(int pipefd)
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
      
-  typedef std::function<bool(ComboAddress, DNSName, uint16_t, dnsheader*)> blockfilter_t;
+  typedef std::function<bool(const DNSQuestion*)> blockfilter_t;
   blockfilter_t blockFilter = 0;
   
   {
@@ -182,8 +182,7 @@ void* tcpClientThread(int pipefd)
         size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
         char queryBuffer[querySize];
         const char* query = queryBuffer;
-        uint16_t queryLen = qlen;
-        readn2WithTimeout(ci.fd, queryBuffer, queryLen, g_tcpRecvTimeout);
+        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
 #ifdef HAVE_DNSCRYPT
         std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
 
@@ -191,7 +190,7 @@ void* tcpClientThread(int pipefd)
           dnsCryptQuery = std::make_shared<DnsCryptQuery>();
           uint16_t decryptedQueryLen = 0;
           vector<uint8_t> response;
-          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, queryLen, dnsCryptQuery, &decryptedQueryLen, true, response);
+          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
 
           if (!decrypted) {
             if (response.size() > 0) {
@@ -200,23 +199,23 @@ void* tcpClientThread(int pipefd)
             }
             break;
           }
-          queryLen = decryptedQueryLen;
+          qlen = decryptedQueryLen;
         }
 #endif
 
        uint16_t qtype;
        unsigned int consumed = 0;
-       DNSName qname(query, queryLen, sizeof(dnsheader), false, &qtype, 0, &consumed);
+       DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, 0, &consumed);
+       DNSQuestion dq(&qname, qtype, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
        string ruleresult;
-       struct dnsheader* dh =(dnsheader*)query;
-       const uint16_t * flags = getFlagsFromDNSHeader(dh);
+       const uint16_t * flags = getFlagsFromDNSHeader(dq.dh);
        uint16_t origFlags = *flags;
        struct timespec now;
        clock_gettime(CLOCK_MONOTONIC, &now);
 
        {
          WriteLock wl(&g_rings.queryLock);
-         g_rings.queryRing.push_back({now,ci.remote,qname,queryLen,qtype,*dh});
+         g_rings.queryRing.push_back({now,ci.remote,qname,dq.len,dq.qtype,*dq.dh});
        }
 
        g_stats.queries++;
@@ -233,27 +232,27 @@ void* tcpClientThread(int pipefd)
          }
        }
 
-        if (dh->rd) {
+        if (dq.dh->rd) {
           g_stats.rdQueries++;
         }
 
         if(blockFilter) {
          std::lock_guard<std::mutex> lock(g_luamutex);
        
-         if(blockFilter(ci.remote, qname, qtype, dh)) {
+         if(blockFilter(&dq)) {
            g_stats.blockFilter++;
            goto drop;
           }
-          if(dh->tc && dh->qr) { // don't truncate on TCP/IP!
-            dh->tc=false;        // maybe we should just pass blockFilter the TCP status
-            dh->qr=false;
+          if(dq.dh->tc && dq.dh->qr) { // don't truncate on TCP/IP!
+            dq.dh->tc=false;        // maybe we should just pass blockFilter the TCP status
+            dq.dh->qr=false;
           }
         }
        
        DNSAction::Action action=DNSAction::Action::None;
        for(const auto& lr : *localRulactions) {
-         if(lr.first->matches(ci.remote, qname, qtype, dh, queryLen)) {
-           action=(*lr.second)(ci.remote, qname, qtype, dh, queryLen, querySize, &ruleresult);
+         if(lr.first->matches(&dq)) {
+           action=(*lr.second)(&dq, &ruleresult);
            if(action != DNSAction::Action::None) {
              lr.first->d_matches++;
              break;
@@ -266,8 +265,8 @@ void* tcpClientThread(int pipefd)
          goto drop;
 
        case DNSAction::Action::Nxdomain:
-         dh->rcode = RCode::NXDomain;
-         dh->qr=true;
+         dq.dh->rcode = RCode::NXDomain;
+         dq.dh->qr=true;
          g_stats.ruleNXDomain++;
          break;
        case DNSAction::Action::Pool: 
@@ -284,9 +283,9 @@ void* tcpClientThread(int pipefd)
          break;
        }
        
-       if(dh->qr) { // something turned it into a response
-         if (putNonBlockingMsgLen(ci.fd, queryLen, g_tcpSendTimeout))
-           writen2WithTimeout(ci.fd, query, queryLen, g_tcpSendTimeout);
+       if(dq.dh->qr) { // something turned it into a response
+         if (putNonBlockingMsgLen(ci.fd, dq.len, g_tcpSendTimeout))
+           writen2WithTimeout(ci.fd, query, dq.len, g_tcpSendTimeout);
 
          g_stats.selfAnswered++;
          goto drop;
@@ -294,7 +293,7 @@ void* tcpClientThread(int pipefd)
 
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
-         ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), ci.remote, qname, qtype, dh);
+         ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), &dq);
        }
        int dsock;
        if(!ds) {
@@ -303,14 +302,14 @@ void* tcpClientThread(int pipefd)
        }
 
         if (ds->useECS) {
-          uint16_t newLen = queryLen;
-          handleEDNSClientSubnet(queryBuffer, querySize, consumed, &newLen, largerQuery, &ednsAdded, ci.remote);
+          uint16_t newLen = dq.len;
+          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, ci.remote);
           if (largerQuery.empty() == false) {
             query = largerQuery.c_str();
-            queryLen = largerQuery.size();
-            querySize = largerQuery.size();
+            dq.len = largerQuery.size();
+            dq.size = largerQuery.size();
           } else {
-            queryLen = newLen;
+            dq.len = newLen;
           }
         }
 
@@ -323,7 +322,7 @@ void* tcpClientThread(int pipefd)
         ds->queries++;
         ds->outstanding++;
 
-       if(qtype == QType::AXFR || qtype == QType::IXFR)  // XXX fixme we really need to do better
+       if(dq.qtype == QType::AXFR || dq.qtype == QType::IXFR)  // XXX fixme we really need to do better
          break;
 
         uint16_t downstream_failures=0;
@@ -340,7 +339,7 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        if(!sendNonBlockingMsgLen(dsock, queryLen, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
+        if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
           close(dsock);
           sockets[ds->remote]=dsock=setupTCPDownstream(ds);
@@ -350,10 +349,10 @@ void* tcpClientThread(int pipefd)
 
         try {
           if (ds->sourceItf == 0) {
-            writen2WithTimeout(dsock, query, queryLen, ds->tcpSendTimeout);
+            writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
           }
           else {
-            sendMsgWithTimeout(dsock, query, queryLen, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
+            sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
           }
         }
         catch(const runtime_error& e) {
@@ -455,7 +454,7 @@ void* tcpClientThread(int pipefd)
         unsigned int udiff = 1000000.0*DiffTime(now,answertime);
         {
           std::lock_guard<std::mutex> lock(g_rings.respMutex);
-          g_rings.respRing.push_back({answertime,  ci.remote, qname, qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
+          g_rings.respRing.push_back({answertime,  ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote});
         }
 
         largerQuery.clear();
index 6b9b38010f924c312d0d78adf631362a2353879a..699dbce6dbcb702b0cc8508e0411f7e7a5f286f3 100644 (file)
@@ -322,17 +322,17 @@ LuaContext g_lua;
 
 GlobalStateHolder<ServerPolicy> g_policy;
 
-shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
   for(auto& d : servers) {
     if(d.second->isUp() && d.second->qps.check())
       return d.second;
   }
-  return leastOutstanding(servers, remote, qname, qtype, dh);
+  return leastOutstanding(servers, dq);
 }
 
 // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
-shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
   vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
   /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
@@ -349,7 +349,7 @@ shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers
   return poss.begin()->second;
 }
 
-shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVector& servers, const DNSQuestion* dq)
 {
   vector<pair<int, shared_ptr<DownstreamState>>> poss;
   int sum=0;
@@ -371,19 +371,19 @@ shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVect
   return p->second;
 }
 
-shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
-  return valrandom(random(), servers, remote, qname, qtype, dh);
+  return valrandom(random(), servers, dq);
 }
 
 static uint32_t g_hashperturb;
-shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
-  return valrandom(qname.hash(g_hashperturb), servers, remote, qname, qtype, dh);
+  return valrandom(dq->qname->hash(g_hashperturb), servers, dq);
 }
 
 
-shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)
+shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq)
 {
   NumberedServerVector poss;
 
@@ -506,7 +506,7 @@ try
   string largerQuery;
   uint16_t qtype;
 
-  typedef std::function<bool(ComboAddress, DNSName, uint16_t, dnsheader*)> blockfilter_t;
+  typedef std::function<bool(DNSQuestion*)> blockfilter_t;
   blockfilter_t blockFilter = 0;
   {
     std::lock_guard<std::mutex> lock(g_luamutex);
@@ -532,7 +532,6 @@ try
       std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
 #endif
       char* query = packet;
-      size_t querySize = sizeof(packet);
       ssize_t ret = recvmsg(cs->udpFD, &msgh, 0);
 
       cs->queries++;
@@ -556,8 +555,7 @@ try
        continue;
       }
 
-      uint16_t len = ret;
-
+      uint16_t len = (uint16_t) ret;
 #ifdef HAVE_DNSCRYPT
       if (cs->dnscryptCtx) {
         vector<uint8_t> response;
@@ -595,11 +593,14 @@ try
       const uint16_t origFlags = *flags;
       unsigned int consumed = 0;
       DNSName qname(query, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+
+      DNSQuestion dq(&qname, qtype, &cs->local, &remote, dh, sizeof(packet), len, false);
+
       struct timespec now;
       clock_gettime(CLOCK_MONOTONIC, &now);
       {
         WriteLock wl(&g_rings.queryLock);
-        g_rings.queryRing.push_back({now,remote,qname,len,qtype,*dh});
+        g_rings.queryRing.push_back({now,remote,qname,dq.len,dq.qtype,*dq.dh});
       }
 
       if(auto got=localDynBlock->lookup(remote)) {
@@ -614,7 +615,7 @@ try
       if(blockFilter) {
        std::lock_guard<std::mutex> lock(g_luamutex);
        
-       if(blockFilter(remote, qname, qtype, dh)) {
+       if(blockFilter(&dq)) {
          g_stats.blockFilter++;
          continue;
        }
@@ -625,8 +626,8 @@ try
       string pool;
 
       for(const auto& lr : *localRulactions) {
-       if(lr.first->matches(remote, qname, qtype, dh, len)) {
-         action=(*lr.second)(remote, qname, qtype, dh, len, querySize, &ruleresult);
+       if(lr.first->matches(&dq)) {
+         action=(*lr.second)(&dq, &ruleresult);
          if(action != DNSAction::Action::None) {
            lr.first->d_matches++;
            break;
@@ -639,8 +640,8 @@ try
        g_stats.ruleDrop++;
        continue;
       case DNSAction::Action::Nxdomain:
-       dh->rcode = RCode::NXDomain;
-       dh->qr=true;
+       dq.dh->rcode = RCode::NXDomain;
+       dq.dh->qr=true;
        g_stats.ruleNXDomain++;
        break;
       case DNSAction::Action::Pool: 
@@ -658,11 +659,11 @@ try
        break;
       }
 
-      if(dh->qr) { // something turned it into a response
+      if(dq.dh->qr) { // something turned it into a response
         char* response = query;
-        uint16_t responseLen = len;
+        uint16_t responseLen = dq.len;
 #ifdef HAVE_DNSCRYPT
-        uint16_t responseSize = querySize;
+        uint16_t responseSize = dq.size;
 #endif
         g_stats.selfAnswered++;
 
@@ -694,7 +695,7 @@ try
       auto policy=localPolicy->policy;
       {
        std::lock_guard<std::mutex> lock(g_luamutex);
-       ss = policy(candidates, remote, qname, qtype, dh).get();
+       ss = policy(candidates, &dq).get();
       }
 
       if(!ss) {
@@ -720,7 +721,7 @@ try
       ids->origRemote = remote;
       ids->sentTime.start();
       ids->qname = qname;
-      ids->qtype = qtype;
+      ids->qtype = dq.qtype;
       ids->origDest.sin4.sin_family=0;
       ids->delayMsec = delayMsec;
       ids->origFlags = origFlags;
@@ -733,11 +734,11 @@ try
       dh->id = idOffset;
 
       if (ss->useECS) {
-        handleEDNSClientSubnet(query, querySize, consumed, &len, largerQuery, &(ids->ednsAdded), remote);
+        handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ids->ednsAdded), remote);
       }
 
       if (largerQuery.empty()) {
-        ret = udpClientSendRequestToBackend(ss, ss->fd, query, len);
+        ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len);
       }
       else {
         ret = udpClientSendRequestToBackend(ss, ss->fd, largerQuery.c_str(), largerQuery.size());
index ed4d873b5b897df759b9d9079f323ff71a90950e..06c1ff503a215956d92e1c0576f90ccf36c4fe7d 100644 (file)
@@ -347,6 +347,20 @@ struct DownstreamState
 };
 using servers_t =vector<std::shared_ptr<DownstreamState>>;
 
+struct DNSQuestion
+{
+  DNSQuestion(const DNSName* name, uint16_t type, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), tcp(isTcp) {};
+
+  const DNSName* qname;
+  const uint16_t qtype;
+  const ComboAddress* local;
+  const ComboAddress* remote;
+  struct dnsheader* dh;
+  size_t size;
+  uint16_t len;
+  const bool tcp;
+};
+
 template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
 
 void* responderThread(std::shared_ptr<DownstreamState> state);
@@ -357,7 +371,7 @@ extern std::string g_outputBuffer; // locking for this is ok, as locked by g_lua
 class DNSRule
 {
 public:
-  virtual bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const =0;
+  virtual bool matches(const DNSQuestion* dq) const =0;
   virtual string toString() const = 0;
   mutable std::atomic<uint64_t> d_matches{0};
 };
@@ -375,12 +389,12 @@ class DNSAction
 {
 public:
   enum class Action { Drop, Nxdomain, Spoof, Allow, HeaderModify, Pool, Delay, None};
-  virtual Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const =0;
+  virtual Action operator()(DNSQuestion*, string* ruleresult) const =0;
   virtual string toString() const = 0;
 };
 
 using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
-typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)> policyfunc_t;
+typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t;
 
 struct ServerPolicy
 {
@@ -433,12 +447,12 @@ void controlThread(int fd, ComboAddress local);
 vector<std::function<void(void)>> setupLua(bool client, const std::string& config);
 NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool);
 
-std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
+std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq);
 
-std::shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
-std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
-std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
-std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
+std::shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq);
 int getEDNSZ(const char* packet, unsigned int len);
 uint16_t getEDNSOptionCode(const char * packet, size_t len);
 void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password);
index dba9a8be91c09e1c164bf7da249f0da12504262c..0597352d9353db0b6747c49fdf9c3034632ac65e 100644 (file)
@@ -26,8 +26,8 @@ addPoolRule("192.168.1.0/24", "abuse")
 
 addQPSPoolRule("com.", 100, "abuse")
 
-function luarule(remote, qname, qtype, dh, len)
-       if(qtype==35) -- NAPTR
+function luarule(dq)
+       if(dq.qtype==35) -- NAPTR
        then
                return DNSAction.Pool, "abuse" -- send to abuse pool
        else
@@ -54,16 +54,16 @@ truncateNMG:addMask("fe80::/16")
 
 print(string.format("Have %d entries in truncate NMG", truncateNMG:size()))
 
-function blockFilter(remote, qname, qtype, dh)
-        print(string.format("Got query from %s, (%s) port number: %d", remote:tostring(), remote:tostringWithPort(), remote:getPort()))
-        if(qtype==255 or truncateNMG:match(remote)) 
+function blockFilter(dq)
+        print(string.format("Got query from %s, (%s) port number: %d", dq.remoteaddr:toString(), dq.remoteaddr:toStringWithPort(), dq.remoteaddr:getPort()))
+        if(dq.qtype==255 or truncateNMG:match(dq.remoteaddr))
         then
 --             print("any query, tc=1")
-               dh:setTC(true)
-               dh:setQR(true)
+               dq.dh:setTC(true)
+               dq.dh:setQR(true)
         end
 
-        if(qname:isPartOf(block))
+        if(dq.qname:isPartOf(block))
         then
                print("Blocking *.powerdns.org")
                return true
@@ -76,7 +76,7 @@ blockFilter = nil -- this is how you disable a filter
 counter=0
 
 -- called to pick a downstream server, ignores 'up' status
-function luaroundrobin(servers, remote, qname, qtype, dh) 
+function luaroundrobin(servers, dq)
         counter=counter+1;
         return servers[1+(counter % #servers)]
 end
index 323ee2b6156c9f9cc3fc813a374c97ade24a03b7..d04d83d1ccae1b4ce1ce66eca0f1371f98f3696b 100644 (file)
@@ -9,9 +9,9 @@ public:
     d_qps(qps), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc)
   {}
 
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    ComboAddress zeroport(remote);
+    ComboAddress zeroport(*dq->remote);
     zeroport.sin4.sin_port=0;
     zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
     auto iter = d_limits.find(zeroport);
@@ -45,7 +45,7 @@ public:
   {}
 
 
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* qd) const override
   {
     return d_qps.check();
   }
@@ -69,9 +69,9 @@ public:
   {
 
   }
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    return d_nmg.match(remote);
+    return d_nmg.match(*dq->remote);
   }
 
   string toString() const override
@@ -86,7 +86,7 @@ class AllRule : public DNSRule
 {
 public:
   AllRule() {}
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
     return true;
   }
@@ -106,9 +106,9 @@ public:
   {
 
   }
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    return dh->cd || (getEDNSZ((const char*)dh, len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
+    return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
   }
 
   string toString() const override
@@ -126,11 +126,11 @@ public:
       d_rules.push_back(r.second);
   } 
 
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
     auto iter = d_rules.begin();
     for(; iter != d_rules.end(); ++iter)
-      if(!(*iter)->matches(remote, qname, qtype, dh, len))
+      if(!(*iter)->matches(dq))
         break;
     return iter == d_rules.end();
   }
@@ -159,9 +159,9 @@ public:
   {
     
   }
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    return d_regex.match(qname.toStringNoDot());
+    return d_regex.match(dq->qname->toStringNoDot());
   }
 
   string toString() const override
@@ -180,9 +180,9 @@ public:
   SuffixMatchNodeRule(const SuffixMatchNode& smn) : d_smn(smn)
   {
   }
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    return d_smn.check(qname);
+    return d_smn.check(*dq->qname);
   }
   string toString() const override
   {
@@ -198,9 +198,9 @@ public:
   QTypeRule(uint16_t qtype) : d_qtype(qtype)
   {
   }
-  bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
+  bool matches(const DNSQuestion* dq) const override
   {
-    return d_qtype == qtype;
+    return d_qtype == dq->qtype;
   }
   string toString() const override
   {
@@ -214,7 +214,7 @@ private:
 class DropAction : public DNSAction
 {
 public:
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     return Action::Drop;
   }
@@ -227,7 +227,7 @@ public:
 class AllowAction : public DNSAction
 {
 public:
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     return Action::Allow;
   }
@@ -243,7 +243,7 @@ class QPSAction : public DNSAction
 public:
   QPSAction(int limit) : d_qps(limit, limit) 
   {}
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     if(d_qps.check())
       return Action::Allow;
@@ -263,7 +263,7 @@ class DelayAction : public DNSAction
 public:
   DelayAction(int msec) : d_msec(msec)
   {}
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     *ruleresult=std::to_string(d_msec);
     return Action::Delay;
@@ -281,7 +281,7 @@ class PoolAction : public DNSAction
 {
 public:
   PoolAction(const std::string& pool) : d_pool(pool) {}
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     *ruleresult=d_pool;
     return Action::Pool;
@@ -300,7 +300,7 @@ class QPSPoolAction : public DNSAction
 {
 public:
   QPSPoolAction(unsigned int limit, const std::string& pool) : d_qps(limit, limit), d_pool(pool) {}
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     if(d_qps.check()) {
       *ruleresult=d_pool;
@@ -323,10 +323,10 @@ class RCodeAction : public DNSAction
 {
 public:
   RCodeAction(int rcode) : d_rcode(rcode) {}
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
-    dh->rcode = d_rcode;
-    dh->qr = true; // for good measure
+    dq->dh->rcode = d_rcode;
+    dq->dh->qr = true; // for good measure
     return Action::HeaderModify;
   }
   string toString() const override
@@ -341,10 +341,10 @@ private:
 class TCAction : public DNSAction
 {
 public:
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
-    dh->tc = true;
-    dh->qr = true; // for good measure
+    dq->dh->tc = true;
+    dq->dh->qr = true; // for good measure
     return Action::HeaderModify;
   }
   string toString() const override
@@ -359,8 +359,9 @@ public:
   SpoofAction(const ComboAddress& a) : d_a(a) { d_aaaa.sin4.sin_family = 0;}
   SpoofAction(const ComboAddress& a, const ComboAddress& aaaa) : d_a(a), d_aaaa(aaaa) {}
   SpoofAction(const string& cname): d_cname(cname) { d_a.sin4.sin_family = 0; d_aaaa.sin4.sin_family = 0; }
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
+    uint16_t qtype = dq->qtype;
     if(d_cname.empty() &&
        ((qtype == QType::A && d_a.sin4.sin_family == 0) ||
         (qtype == QType::AAAA && d_aaaa.sin4.sin_family == 0) || (qtype != QType::A && qtype != QType::AAAA)))
@@ -381,19 +382,19 @@ public:
                                       0, 0, 0, 60,   // TTL
                                       0, rdatalen};
 
-    DNSName ignore((char*)dh, len, sizeof(dnsheader), false, 0, 0, &consumed);
+    DNSName ignore((char*)dq->dh, dq->len, sizeof(dnsheader), false, 0, 0, &consumed);
 
-    if (bufferSize < (sizeof(dnsheader) + consumed + 4 + sizeof(recordstart) + rdatalen)) {
+    if (dq->size < (sizeof(dnsheader) + consumed + 4 + sizeof(recordstart) + rdatalen)) {
       return Action::None;
     }
 
-    dh->qr = true; // for good measure
-    dh->ra = dh->rd; // for good measure
-    dh->ad = false;
-    dh->ancount = htons(1);
-    dh->arcount = 0; // for now, forget about your EDNS, we're marching over it
+    dq->dh->qr = true; // for good measure
+    dq->dh->ra = dq->dh->rd; // for good measure
+    dq->dh->ad = false;
+    dq->dh->ancount = htons(1);
+    dq->dh->arcount = 0; // for now, forget about your EDNS, we're marching over it
 
-    char* dest = ((char*)dh) +sizeof(dnsheader) + consumed + 4;
+    char* dest = ((char*)dq->dh) +sizeof(dnsheader) + consumed + 4;
     memcpy(dest, recordstart, sizeof(recordstart));
     if(qtype==QType::A) 
       memcpy(dest+sizeof(recordstart), &d_a.sin4.sin_addr.s_addr, 4);
@@ -403,7 +404,7 @@ public:
       string wireData = d_cname.toDNSString();
       memcpy(dest+sizeof(recordstart), wireData.c_str(), wireData.length());
     }
-    len = (dest + sizeof(recordstart) + rdatalen) - (char*)dh;
+    dq->len = (dest + sizeof(recordstart) + rdatalen) - (char*)dq->dh;
     return Action::HeaderModify;
   }
   string toString() const override
@@ -432,9 +433,9 @@ private:
 class NoRecurseAction : public DNSAction
 {
 public:
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
-    dh->rd = false;
+    dq->dh->rd = false;
     return Action::HeaderModify;
   }
   string toString() const override
@@ -462,15 +463,15 @@ public:
     if(d_fp)
       fclose(d_fp);
   }
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
     if(!d_fp) {
-      vinfolog("Packet from %s for %s %s with id %d", remote.toStringWithPort(), qname.toString(), QType(qtype).getName(), dh->id);
+      vinfolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id);
     }
     else {
-      string out = qname.toDNSString();
+      string out = dq->qname->toDNSString();
       fwrite(out.c_str(), 1, out.size(), d_fp);
-      fwrite((void*)&qtype, 1, 2, d_fp);
+      fwrite((void*)&dq->qtype, 1, 2, d_fp);
     }
     return Action::None;
   }
@@ -487,9 +488,9 @@ private:
 class DisableValidationAction : public DNSAction
 {
 public:
-  DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
   {
-    dh->cd = true;
+    dq->dh->cd = true;
     return Action::HeaderModify;
   }
   string toString() const override
index 3afcded2177e7034debe68456e8fd99451fe6b0b..c92cb3ea4592de69c895403b11cbd0eafc09fa45 100644 (file)
@@ -41,8 +41,8 @@ class DNSDistTest(unittest.TestCase):
     mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com."))
     addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(4))
     block=newDNSName("powerdns.org.")
-    function blockFilter(remote, qname, qtype, dh)
-        if(qname:isPartOf(block))
+    function blockFilter(dq)
+        if(dq.qname:isPartOf(block))
         then
             print("Blocking *.powerdns.org")
             return true
index 949edb7bba5afbcf13f32b23abc6f2cbdb26285f..63c7334cc21f942a4f7f91bf882c4ee0917749ef 100644 (file)
@@ -17,8 +17,8 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
     _config_template = """
     truncateTC(true)
     block=newDNSName("powerdns.org.")
-    function blockFilter(remote, qname, qtype, dh)
-        if(qname:isPartOf(block))
+    function blockFilter(dq)
+        if(dq.qname:isPartOf(block))
         then
             print("Blocking *.powerdns.org")
             return true
@@ -152,8 +152,8 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
     _config_template = """
     truncateTC(true)
     block=newDNSName("powerdns.org.")
-    function blockFilter(remote, qname, qtype, dh)
-        if(qname:isPartOf(block))
+    function blockFilter(dq)
+        if(dq.qname:isPartOf(block))
         then
             print("Blocking *.powerdns.org")
             return true