From 497a6e3ae74fcbd1d68b59c94aa66277e997cedc Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Thu, 14 Jan 2016 12:57:33 +0100 Subject: [PATCH] dnsdist: Replace the Lua params with a DNSQuestion `dq` object In order to: 1. Be able to add functions/member without breaking the API 2. Being as compatible as possible with the PowerDNS Lua API To limit the parsing/copy to a minimum, this DNSQuestion differs from the PowerDNS one. Most Lua members are properly wrapped, but it currently lacks some advanced functions like `getRecords()` or `setRecords()`, that we might add later. In addition to the existing `tostring()`, this commit adds `toString()` ones to match the PowerDNS syntax. LuaWrapper is supposed to support read-only members, where you only define the getter and no setter, but I can't find the right syntax for that to work, so for now the setter are present for read-only members, and just do nothing. --- pdns/README-dnsdist.md | 28 ++++-- pdns/dnsdist-lua.cc | 29 ++++-- pdns/dnsdist-tcp.cc | 63 +++++++------ pdns/dnsdist.cc | 53 +++++------ pdns/dnsdist.hh | 30 ++++-- pdns/dnsdistconf.lua | 18 ++-- pdns/dnsrulactions.hh | 93 ++++++++++--------- regression-tests.dnsdist/dnsdisttests.py | 4 +- .../test_EdnsClientSubnet.py | 8 +- 9 files changed, 184 insertions(+), 142 deletions(-) diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index e7c9711ce..408b18532 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -410,8 +410,8 @@ table of domain suffixes. This is identical to how `addPoolRule()` selects. The function should look like this: ``` -function luarule(remote, qname, qtype, dh, len) - if(qtype==35) -- NAPTR +function luarule(dq) + if(dq.qtype==35) -- NAPTR then return DNSAction.Pool, "abuse" -- send to abuse pool else @@ -570,7 +570,7 @@ for example: ``` counter=0 -function luaroundrobin(servers, remote, qname, qtype, dh) +function luaroundrobin(servers, dq) counter=counter+1 return servers[1+(counter % #servers)] end @@ -589,12 +589,12 @@ To implement a split horizon, try: ``` authServer=newServer({address="2001:888:2000:1d::2", pool="auth"}) -function splitSetup(servers, remote, qname, qtype, dh) - if(dh:getRD() == false) +function splitSetup(servers, dq) + if(dq.dh:getRD() == false) then - return leastOutstanding.policy(getPoolServers("auth"), remote, qname, qtype, dh) + return leastOutstanding.policy(getPoolServers("auth"), dq) else - return leastOutstanding.policy(servers, remote, qname, qtype, dh) + return leastOutstanding.policy(servers, dq) end end @@ -873,6 +873,16 @@ instantiate a server with additional parameters * `newDNSName(name)`: make a DNSName based on this .-terminated name * member `isPartOf(dnsname)`: is this dnsname part of that dnsname * member `tostring()`: return as a human friendly . terminated string + * DNSQuestion related: + * member `dh`: DNSHeader + * member `len`: the question length + * member `localaddr`: ComboAddress of the local bind this question was received on + * member `qname`: DNSName of this question + * member `qtype`: QType (as an unsigned integer) of this question + * member `remoteaddr`: ComboAddress of the remote client + * member `rcode`: RCode of this question + * member `size`: the total size of the buffer starting at `dh` + * member `tcp`: whether this question was received over a TCP socket * DNSHeader related * member `getRD()`: get recursion desired flag * member `setRD(bool)`: set recursion desired flag @@ -909,10 +919,10 @@ All hooks --------- `dnsdist` can call Lua per packet if so configured, and will do so with the following hooks: - * `bool blockFilter(ComboAddress, DNSName, qtype, DNSHeader)`: if defined, called for every function. If this + * `bool blockFilter(ComboAddress, DNSQuestion)`: if defined, called for every function. If this returns true, the packet is dropped. If false is returned, `dnsdist` will check if the DNSHeader indicates the packet is now a query response. If so, `dnsdist` will answer the client directly with the modified packet. - * `server policy(candidates, ComboAddress, DNSName, qtype, DNSHeader)`: if configured with `setServerPolicyLua()` + * `server policy(candidates, DNSQuestion)`: if configured with `setServerPolicyLua()` gets called for every packet. Candidates is a table of potential servers to pick from, ComboAddress is the address of the requestor, DNSName and qtype describe name and type of query. DNSHeader meanwhile is available for your inspection. diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 91df22529..12d442b0d 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -16,19 +16,19 @@ static vector>* g_launchWork; class LuaAction : public DNSAction { public: - typedef std::function(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t len, uint16_t bufferSize)> func_t; + typedef std::function(DNSQuestion* dq)> func_t; LuaAction(LuaAction::func_t func) : d_func(func) {} - Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + Action operator()(DNSQuestion* dq, string* ruleresult) const override { - auto ret = d_func(remote, qname, qtype, dh, len, bufferSize); + auto ret = d_func(dq); if(ruleresult) *ruleresult=std::get<1>(ret); return (Action)std::get<0>(ret); } - string toString() const + string toString() const override { return "Lua script"; } @@ -658,12 +658,13 @@ vector> setupLua(bool client, const std::string& confi } int matches=0; + ComboAddress dummy("127.0.0.1"); DTime dt; dt.set(); for(int n=0; n < times; ++n) { const item& i = items[n % items.size()]; - struct dnsheader* dh = (struct dnsheader*)&i.packet[0]; - if(rule->matches(i.rem, i.qname, i.qtype, dh, i.packet.size())) + DNSQuestion dq(&i.qname, i.qtype, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false); + if(rule->matches(&dq)) matches++; } double udiff=dt.udiff(); @@ -864,9 +865,12 @@ vector> setupLua(bool client, const std::string& confi g_lua.registerFunction("tostring", &ComboAddress::toString); g_lua.registerFunction("tostringWithPort", &ComboAddress::toStringWithPort); + g_lua.registerFunction("toString", &ComboAddress::toString); + g_lua.registerFunction("toStringWithPort", &ComboAddress::toStringWithPort); g_lua.registerFunction("getPort", [](const ComboAddress& ca) { return ntohs(ca.sin4.sin_port); } ); g_lua.registerFunction("isPartOf", &DNSName::isPartOf); g_lua.registerFunction("tostring", [](const DNSName&dn ) { return dn.toString(); }); + g_lua.registerFunction("toString", [](const DNSName&dn ) { return dn.toString(); }); g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); }); g_lua.writeFunction("newSuffixMatchNode", []() { return SuffixMatchNode(); }); @@ -1178,6 +1182,19 @@ vector> setupLua(bool client, const std::string& confi } }); + /* DNSQuestion bindings */ + /* PowerDNS DNSQuestion compat */ + g_lua.registerMember("localaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.local; }, [](DNSQuestion& dq, const ComboAddress newLocal) { (void) newLocal; }); + g_lua.registerMember("qname", [](const DNSQuestion& dq) -> const DNSName { return *dq.qname; }, [](DNSQuestion& dq, const DNSName newName) { (void) newName; }); + g_lua.registerMember("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; }); + g_lua.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.dh->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.dh->rcode = newRCode; }); + g_lua.registerMember("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.remote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; }); + /* DNSDist DNSQuestion */ + g_lua.registerMember("dh", &DNSQuestion::dh); + g_lua.registerMember("len", [](const DNSQuestion& dq) -> uint16_t { return dq.len; }, [](DNSQuestion& dq, uint16_t newlen) { dq.len = newlen; }); + g_lua.registerMember("size", [](const DNSQuestion& dq) -> size_t { return dq.size; }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; }); + g_lua.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); + g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { g_maxTCPClientThreads = max; }); g_lua.writeFunction("setECSSourcePrefixV4", [](uint16_t prefix) { g_ECSSourcePrefixV4=prefix; }); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index e70599d65..2f467bdb6 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -131,7 +131,7 @@ void* tcpClientThread(int pipefd) /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ - typedef std::function blockfilter_t; + typedef std::function blockfilter_t; blockfilter_t blockFilter = 0; { @@ -182,8 +182,7 @@ void* tcpClientThread(int pipefd) size_t querySize = qlen <= 4096 ? qlen + 512 : qlen; char queryBuffer[querySize]; const char* query = queryBuffer; - uint16_t queryLen = qlen; - readn2WithTimeout(ci.fd, queryBuffer, queryLen, g_tcpRecvTimeout); + readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout); #ifdef HAVE_DNSCRYPT std::shared_ptr dnsCryptQuery = 0; @@ -191,7 +190,7 @@ void* tcpClientThread(int pipefd) dnsCryptQuery = std::make_shared(); uint16_t decryptedQueryLen = 0; vector response; - bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, queryLen, dnsCryptQuery, &decryptedQueryLen, true, response); + bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response); if (!decrypted) { if (response.size() > 0) { @@ -200,23 +199,23 @@ void* tcpClientThread(int pipefd) } break; } - queryLen = decryptedQueryLen; + qlen = decryptedQueryLen; } #endif uint16_t qtype; unsigned int consumed = 0; - DNSName qname(query, queryLen, sizeof(dnsheader), false, &qtype, 0, &consumed); + DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, 0, &consumed); + DNSQuestion dq(&qname, qtype, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true); string ruleresult; - struct dnsheader* dh =(dnsheader*)query; - const uint16_t * flags = getFlagsFromDNSHeader(dh); + const uint16_t * flags = getFlagsFromDNSHeader(dq.dh); uint16_t origFlags = *flags; struct timespec now; clock_gettime(CLOCK_MONOTONIC, &now); { WriteLock wl(&g_rings.queryLock); - g_rings.queryRing.push_back({now,ci.remote,qname,queryLen,qtype,*dh}); + g_rings.queryRing.push_back({now,ci.remote,qname,dq.len,dq.qtype,*dq.dh}); } g_stats.queries++; @@ -233,27 +232,27 @@ void* tcpClientThread(int pipefd) } } - if (dh->rd) { + if (dq.dh->rd) { g_stats.rdQueries++; } if(blockFilter) { std::lock_guard lock(g_luamutex); - if(blockFilter(ci.remote, qname, qtype, dh)) { + if(blockFilter(&dq)) { g_stats.blockFilter++; goto drop; } - if(dh->tc && dh->qr) { // don't truncate on TCP/IP! - dh->tc=false; // maybe we should just pass blockFilter the TCP status - dh->qr=false; + if(dq.dh->tc && dq.dh->qr) { // don't truncate on TCP/IP! + dq.dh->tc=false; // maybe we should just pass blockFilter the TCP status + dq.dh->qr=false; } } DNSAction::Action action=DNSAction::Action::None; for(const auto& lr : *localRulactions) { - if(lr.first->matches(ci.remote, qname, qtype, dh, queryLen)) { - action=(*lr.second)(ci.remote, qname, qtype, dh, queryLen, querySize, &ruleresult); + if(lr.first->matches(&dq)) { + action=(*lr.second)(&dq, &ruleresult); if(action != DNSAction::Action::None) { lr.first->d_matches++; break; @@ -266,8 +265,8 @@ void* tcpClientThread(int pipefd) goto drop; case DNSAction::Action::Nxdomain: - dh->rcode = RCode::NXDomain; - dh->qr=true; + dq.dh->rcode = RCode::NXDomain; + dq.dh->qr=true; g_stats.ruleNXDomain++; break; case DNSAction::Action::Pool: @@ -284,9 +283,9 @@ void* tcpClientThread(int pipefd) break; } - if(dh->qr) { // something turned it into a response - if (putNonBlockingMsgLen(ci.fd, queryLen, g_tcpSendTimeout)) - writen2WithTimeout(ci.fd, query, queryLen, g_tcpSendTimeout); + if(dq.dh->qr) { // something turned it into a response + if (putNonBlockingMsgLen(ci.fd, dq.len, g_tcpSendTimeout)) + writen2WithTimeout(ci.fd, query, dq.len, g_tcpSendTimeout); g_stats.selfAnswered++; goto drop; @@ -294,7 +293,7 @@ void* tcpClientThread(int pipefd) { std::lock_guard lock(g_luamutex); - ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), ci.remote, qname, qtype, dh); + ds = localPolicy->policy(getDownstreamCandidates(g_dstates.getCopy(), pool), &dq); } int dsock; if(!ds) { @@ -303,14 +302,14 @@ void* tcpClientThread(int pipefd) } if (ds->useECS) { - uint16_t newLen = queryLen; - handleEDNSClientSubnet(queryBuffer, querySize, consumed, &newLen, largerQuery, &ednsAdded, ci.remote); + uint16_t newLen = dq.len; + handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, ci.remote); if (largerQuery.empty() == false) { query = largerQuery.c_str(); - queryLen = largerQuery.size(); - querySize = largerQuery.size(); + dq.len = largerQuery.size(); + dq.size = largerQuery.size(); } else { - queryLen = newLen; + dq.len = newLen; } } @@ -323,7 +322,7 @@ void* tcpClientThread(int pipefd) ds->queries++; ds->outstanding++; - if(qtype == QType::AXFR || qtype == QType::IXFR) // XXX fixme we really need to do better + if(dq.qtype == QType::AXFR || dq.qtype == QType::IXFR) // XXX fixme we really need to do better break; uint16_t downstream_failures=0; @@ -340,7 +339,7 @@ void* tcpClientThread(int pipefd) break; } - if(!sendNonBlockingMsgLen(dsock, queryLen, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) { + if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) { vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName()); close(dsock); sockets[ds->remote]=dsock=setupTCPDownstream(ds); @@ -350,10 +349,10 @@ void* tcpClientThread(int pipefd) try { if (ds->sourceItf == 0) { - writen2WithTimeout(dsock, query, queryLen, ds->tcpSendTimeout); + writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout); } else { - sendMsgWithTimeout(dsock, query, queryLen, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf); + sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf); } } catch(const runtime_error& e) { @@ -455,7 +454,7 @@ void* tcpClientThread(int pipefd) unsigned int udiff = 1000000.0*DiffTime(now,answertime); { std::lock_guard lock(g_rings.respMutex); - g_rings.respRing.push_back({answertime, ci.remote, qname, qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote}); + g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote}); } largerQuery.clear(); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 6b9b38010..699dbce6d 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -322,17 +322,17 @@ LuaContext g_lua; GlobalStateHolder g_policy; -shared_ptr firstAvailable(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq) { for(auto& d : servers) { if(d.second->isUp() && d.second->qps.check()) return d.second; } - return leastOutstanding(servers, remote, qname, qtype, dh); + return leastOutstanding(servers, dq); } // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest -shared_ptr leastOutstanding(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq) { vector, shared_ptr>> poss; /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort, @@ -349,7 +349,7 @@ shared_ptr leastOutstanding(const NumberedServerVector& servers return poss.begin()->second; } -shared_ptr valrandom(unsigned int val, const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr valrandom(unsigned int val, const NumberedServerVector& servers, const DNSQuestion* dq) { vector>> poss; int sum=0; @@ -371,19 +371,19 @@ shared_ptr valrandom(unsigned int val, const NumberedServerVect return p->second; } -shared_ptr wrandom(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr wrandom(const NumberedServerVector& servers, const DNSQuestion* dq) { - return valrandom(random(), servers, remote, qname, qtype, dh); + return valrandom(random(), servers, dq); } static uint32_t g_hashperturb; -shared_ptr whashed(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr whashed(const NumberedServerVector& servers, const DNSQuestion* dq) { - return valrandom(qname.hash(g_hashperturb), servers, remote, qname, qtype, dh); + return valrandom(dq->qname->hash(g_hashperturb), servers, dq); } -shared_ptr roundrobin(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh) +shared_ptr roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq) { NumberedServerVector poss; @@ -506,7 +506,7 @@ try string largerQuery; uint16_t qtype; - typedef std::function blockfilter_t; + typedef std::function blockfilter_t; blockfilter_t blockFilter = 0; { std::lock_guard lock(g_luamutex); @@ -532,7 +532,6 @@ try std::shared_ptr dnsCryptQuery = 0; #endif char* query = packet; - size_t querySize = sizeof(packet); ssize_t ret = recvmsg(cs->udpFD, &msgh, 0); cs->queries++; @@ -556,8 +555,7 @@ try continue; } - uint16_t len = ret; - + uint16_t len = (uint16_t) ret; #ifdef HAVE_DNSCRYPT if (cs->dnscryptCtx) { vector response; @@ -595,11 +593,14 @@ try const uint16_t origFlags = *flags; unsigned int consumed = 0; DNSName qname(query, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + + DNSQuestion dq(&qname, qtype, &cs->local, &remote, dh, sizeof(packet), len, false); + struct timespec now; clock_gettime(CLOCK_MONOTONIC, &now); { WriteLock wl(&g_rings.queryLock); - g_rings.queryRing.push_back({now,remote,qname,len,qtype,*dh}); + g_rings.queryRing.push_back({now,remote,qname,dq.len,dq.qtype,*dq.dh}); } if(auto got=localDynBlock->lookup(remote)) { @@ -614,7 +615,7 @@ try if(blockFilter) { std::lock_guard lock(g_luamutex); - if(blockFilter(remote, qname, qtype, dh)) { + if(blockFilter(&dq)) { g_stats.blockFilter++; continue; } @@ -625,8 +626,8 @@ try string pool; for(const auto& lr : *localRulactions) { - if(lr.first->matches(remote, qname, qtype, dh, len)) { - action=(*lr.second)(remote, qname, qtype, dh, len, querySize, &ruleresult); + if(lr.first->matches(&dq)) { + action=(*lr.second)(&dq, &ruleresult); if(action != DNSAction::Action::None) { lr.first->d_matches++; break; @@ -639,8 +640,8 @@ try g_stats.ruleDrop++; continue; case DNSAction::Action::Nxdomain: - dh->rcode = RCode::NXDomain; - dh->qr=true; + dq.dh->rcode = RCode::NXDomain; + dq.dh->qr=true; g_stats.ruleNXDomain++; break; case DNSAction::Action::Pool: @@ -658,11 +659,11 @@ try break; } - if(dh->qr) { // something turned it into a response + if(dq.dh->qr) { // something turned it into a response char* response = query; - uint16_t responseLen = len; + uint16_t responseLen = dq.len; #ifdef HAVE_DNSCRYPT - uint16_t responseSize = querySize; + uint16_t responseSize = dq.size; #endif g_stats.selfAnswered++; @@ -694,7 +695,7 @@ try auto policy=localPolicy->policy; { std::lock_guard lock(g_luamutex); - ss = policy(candidates, remote, qname, qtype, dh).get(); + ss = policy(candidates, &dq).get(); } if(!ss) { @@ -720,7 +721,7 @@ try ids->origRemote = remote; ids->sentTime.start(); ids->qname = qname; - ids->qtype = qtype; + ids->qtype = dq.qtype; ids->origDest.sin4.sin_family=0; ids->delayMsec = delayMsec; ids->origFlags = origFlags; @@ -733,11 +734,11 @@ try dh->id = idOffset; if (ss->useECS) { - handleEDNSClientSubnet(query, querySize, consumed, &len, largerQuery, &(ids->ednsAdded), remote); + handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ids->ednsAdded), remote); } if (largerQuery.empty()) { - ret = udpClientSendRequestToBackend(ss, ss->fd, query, len); + ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len); } else { ret = udpClientSendRequestToBackend(ss, ss->fd, largerQuery.c_str(), largerQuery.size()); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index ed4d873b5..06c1ff503 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -347,6 +347,20 @@ struct DownstreamState }; using servers_t =vector>; +struct DNSQuestion +{ + DNSQuestion(const DNSName* name, uint16_t type, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), tcp(isTcp) {}; + + const DNSName* qname; + const uint16_t qtype; + const ComboAddress* local; + const ComboAddress* remote; + struct dnsheader* dh; + size_t size; + uint16_t len; + const bool tcp; +}; + template using NumberedVector = std::vector >; void* responderThread(std::shared_ptr state); @@ -357,7 +371,7 @@ extern std::string g_outputBuffer; // locking for this is ok, as locked by g_lua class DNSRule { public: - virtual bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const =0; + virtual bool matches(const DNSQuestion* dq) const =0; virtual string toString() const = 0; mutable std::atomic d_matches{0}; }; @@ -375,12 +389,12 @@ class DNSAction { public: enum class Action { Drop, Nxdomain, Spoof, Allow, HeaderModify, Pool, Delay, None}; - virtual Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const =0; + virtual Action operator()(DNSQuestion*, string* ruleresult) const =0; virtual string toString() const = 0; }; using NumberedServerVector = NumberedVector>; -typedef std::function(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh)> policyfunc_t; +typedef std::function(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t; struct ServerPolicy { @@ -433,12 +447,12 @@ void controlThread(int fd, ComboAddress local); vector> setupLua(bool client, const std::string& config); NumberedServerVector getDownstreamCandidates(const servers_t& servers, const std::string& pool); -std::shared_ptr firstAvailable(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh); +std::shared_ptr firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr leastOutstanding(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh); -std::shared_ptr wrandom(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh); -std::shared_ptr whashed(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh); -std::shared_ptr roundrobin(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh); +std::shared_ptr leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq); +std::shared_ptr wrandom(const NumberedServerVector& servers, const DNSQuestion* dq); +std::shared_ptr whashed(const NumberedServerVector& servers, const DNSQuestion* dq); +std::shared_ptr roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq); int getEDNSZ(const char* packet, unsigned int len); uint16_t getEDNSOptionCode(const char * packet, size_t len); void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password); diff --git a/pdns/dnsdistconf.lua b/pdns/dnsdistconf.lua index dba9a8be9..0597352d9 100644 --- a/pdns/dnsdistconf.lua +++ b/pdns/dnsdistconf.lua @@ -26,8 +26,8 @@ addPoolRule("192.168.1.0/24", "abuse") addQPSPoolRule("com.", 100, "abuse") -function luarule(remote, qname, qtype, dh, len) - if(qtype==35) -- NAPTR +function luarule(dq) + if(dq.qtype==35) -- NAPTR then return DNSAction.Pool, "abuse" -- send to abuse pool else @@ -54,16 +54,16 @@ truncateNMG:addMask("fe80::/16") print(string.format("Have %d entries in truncate NMG", truncateNMG:size())) -function blockFilter(remote, qname, qtype, dh) - print(string.format("Got query from %s, (%s) port number: %d", remote:tostring(), remote:tostringWithPort(), remote:getPort())) - if(qtype==255 or truncateNMG:match(remote)) +function blockFilter(dq) + print(string.format("Got query from %s, (%s) port number: %d", dq.remoteaddr:toString(), dq.remoteaddr:toStringWithPort(), dq.remoteaddr:getPort())) + if(dq.qtype==255 or truncateNMG:match(dq.remoteaddr)) then -- print("any query, tc=1") - dh:setTC(true) - dh:setQR(true) + dq.dh:setTC(true) + dq.dh:setQR(true) end - if(qname:isPartOf(block)) + if(dq.qname:isPartOf(block)) then print("Blocking *.powerdns.org") return true @@ -76,7 +76,7 @@ blockFilter = nil -- this is how you disable a filter counter=0 -- called to pick a downstream server, ignores 'up' status -function luaroundrobin(servers, remote, qname, qtype, dh) +function luaroundrobin(servers, dq) counter=counter+1; return servers[1+(counter % #servers)] end diff --git a/pdns/dnsrulactions.hh b/pdns/dnsrulactions.hh index 323ee2b61..d04d83d1c 100644 --- a/pdns/dnsrulactions.hh +++ b/pdns/dnsrulactions.hh @@ -9,9 +9,9 @@ public: d_qps(qps), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc) {} - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - ComboAddress zeroport(remote); + ComboAddress zeroport(*dq->remote); zeroport.sin4.sin_port=0; zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); auto iter = d_limits.find(zeroport); @@ -45,7 +45,7 @@ public: {} - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* qd) const override { return d_qps.check(); } @@ -69,9 +69,9 @@ public: { } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - return d_nmg.match(remote); + return d_nmg.match(*dq->remote); } string toString() const override @@ -86,7 +86,7 @@ class AllRule : public DNSRule { public: AllRule() {} - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { return true; } @@ -106,9 +106,9 @@ public: { } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - return dh->cd || (getEDNSZ((const char*)dh, len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. + return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. } string toString() const override @@ -126,11 +126,11 @@ public: d_rules.push_back(r.second); } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { auto iter = d_rules.begin(); for(; iter != d_rules.end(); ++iter) - if(!(*iter)->matches(remote, qname, qtype, dh, len)) + if(!(*iter)->matches(dq)) break; return iter == d_rules.end(); } @@ -159,9 +159,9 @@ public: { } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - return d_regex.match(qname.toStringNoDot()); + return d_regex.match(dq->qname->toStringNoDot()); } string toString() const override @@ -180,9 +180,9 @@ public: SuffixMatchNodeRule(const SuffixMatchNode& smn) : d_smn(smn) { } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - return d_smn.check(qname); + return d_smn.check(*dq->qname); } string toString() const override { @@ -198,9 +198,9 @@ public: QTypeRule(uint16_t qtype) : d_qtype(qtype) { } - bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override + bool matches(const DNSQuestion* dq) const override { - return d_qtype == qtype; + return d_qtype == dq->qtype; } string toString() const override { @@ -214,7 +214,7 @@ private: class DropAction : public DNSAction { public: - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { return Action::Drop; } @@ -227,7 +227,7 @@ public: class AllowAction : public DNSAction { public: - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { return Action::Allow; } @@ -243,7 +243,7 @@ class QPSAction : public DNSAction public: QPSAction(int limit) : d_qps(limit, limit) {} - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { if(d_qps.check()) return Action::Allow; @@ -263,7 +263,7 @@ class DelayAction : public DNSAction public: DelayAction(int msec) : d_msec(msec) {} - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { *ruleresult=std::to_string(d_msec); return Action::Delay; @@ -281,7 +281,7 @@ class PoolAction : public DNSAction { public: PoolAction(const std::string& pool) : d_pool(pool) {} - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { *ruleresult=d_pool; return Action::Pool; @@ -300,7 +300,7 @@ class QPSPoolAction : public DNSAction { public: QPSPoolAction(unsigned int limit, const std::string& pool) : d_qps(limit, limit), d_pool(pool) {} - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { if(d_qps.check()) { *ruleresult=d_pool; @@ -323,10 +323,10 @@ class RCodeAction : public DNSAction { public: RCodeAction(int rcode) : d_rcode(rcode) {} - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { - dh->rcode = d_rcode; - dh->qr = true; // for good measure + dq->dh->rcode = d_rcode; + dq->dh->qr = true; // for good measure return Action::HeaderModify; } string toString() const override @@ -341,10 +341,10 @@ private: class TCAction : public DNSAction { public: - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { - dh->tc = true; - dh->qr = true; // for good measure + dq->dh->tc = true; + dq->dh->qr = true; // for good measure return Action::HeaderModify; } string toString() const override @@ -359,8 +359,9 @@ public: SpoofAction(const ComboAddress& a) : d_a(a) { d_aaaa.sin4.sin_family = 0;} SpoofAction(const ComboAddress& a, const ComboAddress& aaaa) : d_a(a), d_aaaa(aaaa) {} SpoofAction(const string& cname): d_cname(cname) { d_a.sin4.sin_family = 0; d_aaaa.sin4.sin_family = 0; } - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { + uint16_t qtype = dq->qtype; if(d_cname.empty() && ((qtype == QType::A && d_a.sin4.sin_family == 0) || (qtype == QType::AAAA && d_aaaa.sin4.sin_family == 0) || (qtype != QType::A && qtype != QType::AAAA))) @@ -381,19 +382,19 @@ public: 0, 0, 0, 60, // TTL 0, rdatalen}; - DNSName ignore((char*)dh, len, sizeof(dnsheader), false, 0, 0, &consumed); + DNSName ignore((char*)dq->dh, dq->len, sizeof(dnsheader), false, 0, 0, &consumed); - if (bufferSize < (sizeof(dnsheader) + consumed + 4 + sizeof(recordstart) + rdatalen)) { + if (dq->size < (sizeof(dnsheader) + consumed + 4 + sizeof(recordstart) + rdatalen)) { return Action::None; } - dh->qr = true; // for good measure - dh->ra = dh->rd; // for good measure - dh->ad = false; - dh->ancount = htons(1); - dh->arcount = 0; // for now, forget about your EDNS, we're marching over it + dq->dh->qr = true; // for good measure + dq->dh->ra = dq->dh->rd; // for good measure + dq->dh->ad = false; + dq->dh->ancount = htons(1); + dq->dh->arcount = 0; // for now, forget about your EDNS, we're marching over it - char* dest = ((char*)dh) +sizeof(dnsheader) + consumed + 4; + char* dest = ((char*)dq->dh) +sizeof(dnsheader) + consumed + 4; memcpy(dest, recordstart, sizeof(recordstart)); if(qtype==QType::A) memcpy(dest+sizeof(recordstart), &d_a.sin4.sin_addr.s_addr, 4); @@ -403,7 +404,7 @@ public: string wireData = d_cname.toDNSString(); memcpy(dest+sizeof(recordstart), wireData.c_str(), wireData.length()); } - len = (dest + sizeof(recordstart) + rdatalen) - (char*)dh; + dq->len = (dest + sizeof(recordstart) + rdatalen) - (char*)dq->dh; return Action::HeaderModify; } string toString() const override @@ -432,9 +433,9 @@ private: class NoRecurseAction : public DNSAction { public: - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { - dh->rd = false; + dq->dh->rd = false; return Action::HeaderModify; } string toString() const override @@ -462,15 +463,15 @@ public: if(d_fp) fclose(d_fp); } - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { if(!d_fp) { - vinfolog("Packet from %s for %s %s with id %d", remote.toStringWithPort(), qname.toString(), QType(qtype).getName(), dh->id); + vinfolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id); } else { - string out = qname.toDNSString(); + string out = dq->qname->toDNSString(); fwrite(out.c_str(), 1, out.size(), d_fp); - fwrite((void*)&qtype, 1, 2, d_fp); + fwrite((void*)&dq->qtype, 1, 2, d_fp); } return Action::None; } @@ -487,9 +488,9 @@ private: class DisableValidationAction : public DNSAction { public: - DNSAction::Action operator()(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, uint16_t& len, uint16_t bufferSize, string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override { - dh->cd = true; + dq->dh->cd = true; return Action::HeaderModify; } string toString() const override diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 3afcded21..c92cb3ea4 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -41,8 +41,8 @@ class DNSDistTest(unittest.TestCase): mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com.")) addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(4)) block=newDNSName("powerdns.org.") - function blockFilter(remote, qname, qtype, dh) - if(qname:isPartOf(block)) + function blockFilter(dq) + if(dq.qname:isPartOf(block)) then print("Blocking *.powerdns.org") return true diff --git a/regression-tests.dnsdist/test_EdnsClientSubnet.py b/regression-tests.dnsdist/test_EdnsClientSubnet.py index 949edb7bb..63c7334cc 100644 --- a/regression-tests.dnsdist/test_EdnsClientSubnet.py +++ b/regression-tests.dnsdist/test_EdnsClientSubnet.py @@ -17,8 +17,8 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): _config_template = """ truncateTC(true) block=newDNSName("powerdns.org.") - function blockFilter(remote, qname, qtype, dh) - if(qname:isPartOf(block)) + function blockFilter(dq) + if(dq.qname:isPartOf(block)) then print("Blocking *.powerdns.org") return true @@ -152,8 +152,8 @@ class TestEdnsClientSubnetOverride(DNSDistTest): _config_template = """ truncateTC(true) block=newDNSName("powerdns.org.") - function blockFilter(remote, qname, qtype, dh) - if(qname:isPartOf(block)) + function blockFilter(dq) + if(dq.qname:isPartOf(block)) then print("Blocking *.powerdns.org") return true -- 2.40.0