From: Remi Gacogne Date: Wed, 26 Jul 2017 16:44:30 +0000 (+0200) Subject: dnsdist: Add support for handling UDP queries via recvmmsg() and sendmmsg() X-Git-Tag: rec-4.1.0-rc1~22^2~6 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=0beaa5c825059986d3ce108a4bd2578a08d6d1d0;p=pdns dnsdist: Add support for handling UDP queries via recvmmsg() and sendmmsg() --- diff --git a/m4/pdns_check_network_libs.m4 b/m4/pdns_check_network_libs.m4 index 19e8d5899..36169bbae 100644 --- a/m4/pdns_check_network_libs.m4 +++ b/m4/pdns_check_network_libs.m4 @@ -3,4 +3,5 @@ AC_DEFUN([PDNS_CHECK_NETWORK_LIBS],[ AC_SEARCH_LIBS([gethostbyname], [nsl]) AC_SEARCH_LIBS([socket], [socket]) AC_SEARCH_LIBS([gethostent], [nsl]) + AC_CHECK_FUNCS([recvmmsg sendmmsg]) ]) diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index f34b22488..04e796322 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -387,6 +387,7 @@ const std::vector g_consoleKeywords{ { "setTCPUseSinglePipe", true, "bool", "whether the incoming TCP connections should be put into a single queue instead of using per-thread queues. Defaults to false" }, { "setTCPRecvTimeout", true, "n", "set the read timeout on TCP connections from the client, in seconds" }, { "setTCPSendTimeout", true, "n", "set the write timeout on TCP connections from the client, in seconds" }, + { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. 
Default to 1 which means that the feature is disabled and recvmsg() is used instead" }, { "setUDPTimeout", true, "n", "set the maximum time dnsdist will wait for a response from a backend over UDP, in seconds" }, { "setVerboseHealthChecks", true, "bool", "set whether health check errors will be logged" }, { "show", true, "string", "outputs `string`" }, diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 2a51d492b..eaaa6c8fc 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -259,7 +259,7 @@ void generateOptRR(const std::string& optRData, string& res) res.append(optRData.c_str(), optRData.length()); } -static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, uint16_t ECSPrefixLength) +static bool replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, uint16_t ECSPrefixLength) { assert(packet != NULL); assert(len != NULL); @@ -277,39 +277,29 @@ static void replaceEDNSClientSubnetOption(char * const packet, const size_t pack const unsigned int newPacketLen = *len + (ECSOption.length() - oldEcsOptionSize); const size_t beforeOptionLen = oldEcsOptionStart - packet; const size_t dataBehindSize = *len - beforeOptionLen - oldEcsOptionSize; - + + /* check that it fits in the existing buffer */ + if (newPacketLen > packetSize) { + return false; + } + /* fix the size of ECS Option RDLen */ uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1]; newRDLen += (ECSOption.size() - oldEcsOptionSize); optRDLen[0] = newRDLen / 256; optRDLen[1] = newRDLen % 256; - - if (newPacketLen <= packetSize) { - /* it fits in the existing buffer */ - if (dataBehindSize > 0) { - memmove(oldEcsOptionStart, 
oldEcsOptionStart + oldEcsOptionSize, dataBehindSize); - } - memcpy(oldEcsOptionStart + dataBehindSize, ECSOption.c_str(), ECSOption.size()); - *len = newPacketLen; - } - else { - /* We need a larger packet */ - if (newPacketLen > largerPacket.capacity()) { - largerPacket.reserve(newPacketLen); - } - /* copy data before the existing option */ - largerPacket.append(packet, beforeOptionLen); - /* copy the new option */ - largerPacket.append(ECSOption); - /* copy data that where behind the existing option */ - if (dataBehindSize > 0) { - largerPacket.append(oldEcsOptionStart + oldEcsOptionSize, dataBehindSize); - } + + if (dataBehindSize > 0) { + memmove(oldEcsOptionStart, oldEcsOptionStart + oldEcsOptionSize, dataBehindSize); } + memcpy(oldEcsOptionStart + dataBehindSize, ECSOption.c_str(), ECSOption.size()); + *len = newPacketLen; } + + return true; } -void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, string& largerPacket, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength) +bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength) { assert(packet != NULL); assert(len != NULL); @@ -330,7 +320,7 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u if (res == 0) { /* there is already an ECS value */ if (overrideExisting) { - replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen, ecsPrefixLength); + return replaceEDNSClientSubnetOption(packet, packetSize, len, remote, ecsOptionStart, ecsOptionSize, optRDLen, ecsPrefixLength); } } else { /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */ @@ -340,24 +330,18 @@ void 
handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u generateECSOption(remote, ECSOption, ecsPrefixLength); const size_t ECSOptionSize = ECSOption.size(); + /* check if the existing buffer is large enough */ + if (packetSize - *len <= ECSOptionSize) { + return false; + } + uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1]; newRDLen += ECSOptionSize; optRDLen[0] = newRDLen / 256; optRDLen[1] = newRDLen % 256; - if (packetSize - *len > ECSOptionSize) { - /* if the existing buffer is large enough */ - memcpy(packet + *len, ECSOption.c_str(), ECSOptionSize); - *len += ECSOptionSize; - } - else { - if (*len + ECSOptionSize > largerPacket.capacity()) { - largerPacket.reserve(*len + ECSOptionSize); - } - - largerPacket.append(packet, *len); - largerPacket.append(ECSOption); - } + memcpy(packet + *len, ECSOption.c_str(), ECSOptionSize); + *len += ECSOptionSize; *ecsAdded = true; } } @@ -368,25 +352,22 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u string optRData; generateECSOption(remote, optRData, ecsPrefixLength); generateOptRR(optRData, EDNSRR); + + /* does it fit in the existing buffer? */ + if (packetSize - *len <= EDNSRR.size()) { + return false; + } + uint16_t arcount = ntohs(dh->arcount); arcount++; dh->arcount = htons(arcount); *ednsAdded = true; - /* does it fit in the existing buffer? 
*/ - if (packetSize - *len > EDNSRR.size()) { - memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size()); - *len += EDNSRR.size(); - } - else { - if (*len + EDNSRR.size() > largerPacket.capacity()) { - largerPacket.reserve(*len + EDNSRR.size()); - } - - largerPacket.append(packet, *len); - largerPacket.append(EDNSRR); - } + memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size()); + *len += EDNSRR.size(); } + + return true; } static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen) diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh index 84013eb05..d2f6a261a 100644 --- a/pdns/dnsdist-ecs.hh +++ b/pdns/dnsdist-ecs.hh @@ -23,7 +23,7 @@ int rewriteResponseWithoutEDNS(const char * packet, size_t len, vector& newContent); int locateEDNSOptRR(char * packet, size_t len, char ** optStart, size_t * optLen, bool * last); -void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, string& largerPacket, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength); +bool handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength); void generateOptRR(const std::string& optRData, string& res); int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove); int rewriteResponseWithoutEDNSOption(const char * packet, const size_t len, const uint16_t optionCodeToSkip, vector& newContent); diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index d896c1627..40c195f4b 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -1463,4 +1463,19 @@ void moreLua(bool client) g_lua.writeFunction("setConsoleConnectionsLogging", [](bool enabled) { g_logConsoleConnections = enabled; }); + + 
g_lua.writeFunction("setUDPMultipleMessagesVectorSize", [](size_t vSize) { + if (g_configurationDone) { + errlog("setUDPMultipleMessagesVectorSize() cannot be used at runtime!"); + g_outputBuffer="setUDPMultipleMessagesVectorSize() cannot be used at runtime!\n"; + return; + } +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + setLuaSideEffect(); + g_udpVectorSize = vSize; +#else + errlog("recvmmsg() support is not available!"); + g_outputBuffer="recvmmsg support is not available!\n"; +#endif + }); } diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index b5fe2fd5b..5d75e811a 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -222,17 +222,8 @@ void* tcpClientThread(int pipefd) bool outstanding = false; time_t lastTCPCleanup = time(nullptr); - - auto localPolicy = g_policy.getLocal(); - auto localRulactions = g_rulactions.getLocal(); + LocalHolders holders; auto localRespRulactions = g_resprulactions.getLocal(); - auto localCacheHitRespRulactions = g_cachehitresprulactions.getLocal(); - auto localDynBlockNMG = g_dynblockNMG.getLocal(); - auto localDynBlockSMT = g_dynblockSMT.getLocal(); - auto localPools = g_pools.getLocal(); -#ifdef HAVE_PROTOBUF - boost::uuids::random_generator uuidGenerator; -#endif #ifdef HAVE_DNSCRYPT /* when the answer is encrypted in place, we need to get a copy of the original header before encryption to fill the ring buffer */ @@ -255,7 +246,6 @@ void* tcpClientThread(int pipefd) delete citmp; uint16_t qlen, rlen; - string largerQuery; vector rewrittenResponse; shared_ptr ds; ComboAddress dest; @@ -264,6 +254,7 @@ void* tcpClientThread(int pipefd) socklen_t len = dest.getSocklen(); size_t queriesCount = 0; time_t connectionStartTime = time(NULL); + std::vector queryBuffer; if (!setNonBlocking(ci.fd)) goto drop; @@ -281,11 +272,16 @@ void* tcpClientThread(int pipefd) if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout)) break; + queriesCount++; + + if (qlen < sizeof(dnsheader)) { + 
g_stats.nonCompliantQueries++; + break; + } + ci.cs->queries++; g_stats.queries++; - queriesCount++; - if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) { vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn); break; @@ -296,29 +292,23 @@ void* tcpClientThread(int pipefd) break; } - if (qlen < sizeof(dnsheader)) { - g_stats.nonCompliantQueries++; - break; - } - bool ednsAdded = false; bool ecsAdded = false; - /* if the query is small, allocate a bit more - memory to be able to spoof the content, + /* allocate a bit more memory to be able to spoof the content, or to add ECS without allocating a new buffer */ - size_t querySize = qlen <= 4096 ? qlen + 512 : qlen; - char queryBuffer[querySize]; - const char* query = queryBuffer; - readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime); + queryBuffer.reserve(qlen + 512); + + char* query = &queryBuffer[0]; + readn2WithTimeout(ci.fd, query, qlen, g_tcpRecvTimeout, remainingTime); #ifdef HAVE_DNSCRYPT - std::shared_ptr dnsCryptQuery = 0; + std::shared_ptr dnsCryptQuery = nullptr; if (ci.cs->dnscryptCtx) { dnsCryptQuery = std::make_shared(); uint16_t decryptedQueryLen = 0; vector response; - bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response); + bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, query, qlen, dnsCryptQuery, &decryptedQueryLen, true, response); if (!decrypted) { if (response.size() > 0) { @@ -329,30 +319,20 @@ void* tcpClientThread(int pipefd) qlen = decryptedQueryLen; } #endif - struct dnsheader* dh = (struct dnsheader*) query; - - if(dh->qr) { // don't respond to responses - g_stats.nonCompliantQueries++; - goto drop; - } + struct dnsheader* dh = reinterpret_cast(query); - if(dh->qdcount == 0) { - g_stats.emptyQueries++; + if (!checkQueryHeaders(dh)) { goto 
drop; } - if (dh->rd) { - g_stats.rdQueries++; - } - const uint16_t* flags = getFlagsFromDNSHeader(dh); uint16_t origFlags = *flags; uint16_t qtype, qclass; unsigned int consumed = 0; DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true); + DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, dh, queryBuffer.capacity(), qlen, true); #ifdef HAVE_PROTOBUF - dq.uniqueId = uuidGenerator(); + dq.uniqueId = t_uuidGenerator(); #endif string poolname; @@ -363,14 +343,14 @@ void* tcpClientThread(int pipefd) gettime(&now); gettime(&queryRealTime, true); - if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, dq, poolname, &delayMsec, now)) { + if (!processQuery(holders, dq, poolname, &delayMsec, now)) { goto drop; } if(dq.dh->qr) { // something turned it into a response restoreFlags(dh, origFlags); #ifdef HAVE_DNSCRYPT - if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) { + if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) { goto drop; } #endif @@ -379,9 +359,9 @@ void* tcpClientThread(int pipefd) goto drop; } - std::shared_ptr serverPool = getPool(*localPools, poolname); + std::shared_ptr serverPool = getPool(*holders.pools, poolname); std::shared_ptr packetCache = nullptr; - auto policy = localPolicy->policy; + auto policy = holders.policy->policy; if (serverPool->policy != nullptr) { policy = serverPool->policy->policy; } @@ -393,14 +373,11 @@ void* tcpClientThread(int pipefd) if (dq.useECS && ds && ds->useECS) { uint16_t newLen = dq.len; - handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength); - if (largerQuery.empty() == false) { - query = largerQuery.c_str(); - dq.len = (uint16_t) largerQuery.size(); - dq.size = largerQuery.size(); - } else { - dq.len = 
newLen; + if (!handleEDNSClientSubnet(query, dq.size, consumed, &newLen, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength)) { + vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort()); + goto drop; } + dq.len = newLen; } uint32_t cacheKey = 0; @@ -413,7 +390,7 @@ void* tcpClientThread(int pipefd) #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; #endif - if (!processResponse(localCacheHitRespRulactions, dr, &delayMsec)) { + if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) { goto drop; } @@ -438,7 +415,7 @@ void* tcpClientThread(int pipefd) dq.dh->qr = true; #ifdef HAVE_DNSCRYPT - if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) { + if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) { goto drop; } #endif @@ -609,7 +586,6 @@ void* tcpClientThread(int pipefd) g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote}); } - largerQuery.clear(); rewrittenResponse.clear(); } } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 8f3634a53..8bba4b3f5 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -53,6 +53,10 @@ #include #endif +#ifdef HAVE_PROTOBUF +thread_local boost::uuids::random_generator t_uuidGenerator; +#endif + /* Known sins: Receiver is currently single threaded @@ -88,6 +92,7 @@ std::vector > g_dynBPFFilters; #endif /* HAVE_EBPF */ vector g_frontends; GlobalStateHolder g_pools; +size_t g_udpVectorSize{1}; bool g_snmpEnabled{false}; bool g_snmpTrapsEnabled{false}; @@ -143,6 +148,9 @@ int g_udpTimeout{2}; bool g_servFailOnNoPolicy{false}; bool g_truncateTC{false}; bool g_fixupCase{0}; + +static const size_t s_udpIncomingBufferSize{1500}; + static void truncateTC(const char* packet, uint16_t* len) try { @@ -498,7 +506,7 @@ try { if (ids->origFD == origFD) { #ifdef HAVE_DNSCRYPT - ids->dnsCryptQuery = 0; + 
ids->dnsCryptQuery = nullptr; #endif ids->origFD = -1; } @@ -850,9 +858,7 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent) } } -bool processQuery(LocalStateHolder >& localDynNMGBlock, - LocalStateHolder >& localDynSMTBlock, - LocalStateHolder, std::shared_ptr > > >& localRulactions, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now) +bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now) { { WriteLock wl(&g_rings.queryLock); @@ -876,7 +882,7 @@ bool processQuery(LocalStateHolder >& localDynNMGBlock, } } - if(auto got=localDynNMGBlock->lookup(*dq.remote)) { + if(auto got = holders.dynNMGBlock->lookup(*dq.remote)) { auto updateBlockStats = [&got]() { g_stats.dynBlocked++; got->second.blocks++; @@ -916,7 +922,7 @@ bool processQuery(LocalStateHolder >& localDynNMGBlock, } } - if(auto got=localDynSMTBlock->lookup(*dq.qname)) { + if(auto got = holders.dynSMTBlock->lookup(*dq.qname)) { auto updateBlockStats = [&got]() { g_stats.dynBlocked++; got->blocks++; @@ -958,7 +964,7 @@ bool processQuery(LocalStateHolder >& localDynNMGBlock, DNSAction::Action action=DNSAction::Action::None; string ruleresult; - for(const auto& lr : *localRulactions) { + for(const auto& lr : *holders.rulactions) { if(lr.first->matches(&dq)) { lr.first->d_matches++; action=(*lr.second)(&dq, &ruleresult); @@ -1073,384 +1079,472 @@ static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, return result; } -// listens to incoming queries, sends out to downstream servers, noting the intended return path -static void* udpClientThread(ClientState* cs) -try +static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest) { - string largerQuery; - uint16_t qtype, qclass; - uint16_t queryId; -#ifdef HAVE_PROTOBUF - boost::uuids::random_generator uuidGenerator; -#endif + if 
(msgh->msg_flags & MSG_TRUNC) { + /* message was too large for our buffer */ + vinfolog("Dropping message too large for our buffer"); + g_stats.nonCompliantQueries++; + return false; + } - auto acl = g_ACL.getLocal(); - auto localPolicy = g_policy.getLocal(); - auto localRulactions = g_rulactions.getLocal(); - auto localCacheHitRespRulactions = g_cachehitresprulactions.getLocal(); - auto localServers = g_dstates.getLocal(); - auto localDynNMGBlock = g_dynblockNMG.getLocal(); - auto localDynSMTBlock = g_dynblockSMT.getLocal(); - auto localPools = g_pools.getLocal(); - - static const size_t vectSize = 50; - struct - { - char packet[4096]; - /* used by HarvestDestinationAddress */ - char cbuf[256]; - ComboAddress remote; - ComboAddress dest; - struct iovec iov; + if(!holders.acl->match(remote)) { + vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort()); + g_stats.aclDrops++; + return false; } - data[vectSize]; - struct mmsghdr msgVec[vectSize]; - struct mmsghdr outMsgVec[vectSize]; - for (size_t idx = 0; idx < vectSize; idx++) { - data[idx].remote.sin4.sin_family = cs->local.sin4.sin_family; + cs.queries++; + g_stats.queries++; - fillMSGHdr(&msgVec[idx].msg_hdr, &data[idx].iov, data[idx].cbuf, sizeof(data[idx].cbuf), data[idx].packet, sizeof(data[idx].packet), &data[idx].remote); + if (HarvestDestinationAddress(msgh, &dest)) { + /* we don't get the port, only the address */ + dest.sin4.sin_port = cs.local.sin4.sin_port; + } + else { + dest.sin4.sin_family = 0; } - for(;;) { + return true; +} + #ifdef HAVE_DNSCRYPT - std::shared_ptr dnsCryptQuery = 0; -#endif - for (size_t idx = 0; idx < vectSize; idx++) { - data[idx].iov.iov_base = data[idx].packet; - data[idx].iov.iov_len = sizeof(data[idx].packet); - } - int msgsGot = recvmmsg(cs->udpFD, msgVec, vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr); +static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr& dnsCryptQuery, const ComboAddress& dest, const 
ComboAddress& remote) +{ + if (cs.dnscryptCtx) { + vector response; + uint16_t decryptedQueryLen = 0; - if (msgsGot <= 0) { - vinfolog("recvmmsg() failed with: %s", strerror(errno)); - continue; + dnsCryptQuery = std::make_shared(); + + bool decrypted = handleDnsCryptQuery(cs.dnscryptCtx, const_cast(query), len, dnsCryptQuery, &decryptedQueryLen, false, response); + + if (!decrypted) { + if (response.size() > 0) { + sendUDPResponse(cs.udpFD, reinterpret_cast(response.data()), static_cast(response.size()), 0, dest, remote); } - //vinfolog("Got %d messages", msgsGot); - unsigned int msgsToSend = 0; + return false; + } - for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) { - const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr; - unsigned int ret = msgVec[msgIdx].msg_len; - const ComboAddress& remote = data[msgIdx].remote; + len = decryptedQueryLen; + } + return true; +} +#endif /* HAVE_DNSCRYPT */ - try { - char* query = data[msgIdx].packet; +bool checkQueryHeaders(const struct dnsheader* dh) +{ + if (dh->qr) { // don't respond to responses + g_stats.nonCompliantQueries++; + return false; + } - queryId = 0; + if (dh->qdcount == 0) { + g_stats.emptyQueries++; + return false; + } - if(!acl->match(remote)) { - vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort()); - g_stats.aclDrops++; - continue; - } + if (dh->rd) { + g_stats.rdQueries++; + } - cs->queries++; - g_stats.queries++; + return true; +} - if(ret < (int)sizeof(struct dnsheader)) { - g_stats.nonCompliantQueries++; - continue; - } +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) +static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, char* cbuf) +{ + outMsg.msg_len = 0; + fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast(response), responseLen, const_cast(&remote)); - if (msgh->msg_flags & MSG_TRUNC) { - /* message was too 
large for our buffer */ - vinfolog("Dropping message too large for our buffer"); - g_stats.nonCompliantQueries++; - continue; - } + if (dest.sin4.sin_family == 0) { + outMsg.msg_hdr.msg_control = nullptr; + } + else { + addCMsgSrcAddr(&outMsg.msg_hdr, cbuf, &dest, 0); + } +} +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ - uint16_t len = (uint16_t) ret; +static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf) +{ + assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr)); + uint16_t queryId = 0; - if (HarvestDestinationAddress(msgh, &data[msgIdx].dest)) { - /* we don't get the port, only the address */ - data[msgIdx].dest.sin4.sin_port = cs->local.sin4.sin_port; - } - else { - data[msgIdx].dest.sin4.sin_family = 0; - } + try { + if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) { + return; + } #ifdef HAVE_DNSCRYPT - if (cs->dnscryptCtx) { - vector response; - uint16_t decryptedQueryLen = 0; - dnsCryptQuery = std::make_shared(); + std::shared_ptr dnsCryptQuery = nullptr; - bool decrypted = handleDnsCryptQuery(cs->dnscryptCtx, query, len, dnsCryptQuery, &decryptedQueryLen, false, response); - - if (!decrypted) { - if (response.size() > 0) { - sendUDPResponse(cs->udpFD, reinterpret_cast(response.data()), (uint16_t) response.size(), 0, data[msgIdx].dest, remote); - } - continue; - } - len = decryptedQueryLen; - } + if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote)) { + return; + } #endif - struct dnsheader* dh = (struct dnsheader*) query; - queryId = ntohs(dh->id); + struct dnsheader* dh = reinterpret_cast(query); + queryId = ntohs(dh->id); - if(dh->qr) { // don't respond to responses - 
g_stats.nonCompliantQueries++; - continue; - } - - if(dh->qdcount == 0) { - g_stats.emptyQueries++; - continue; - } + if (!checkQueryHeaders(dh)) { + return; + } - if (dh->rd) { - g_stats.rdQueries++; - } + const uint16_t * flags = getFlagsFromDNSHeader(dh); + const uint16_t origFlags = *flags; + uint16_t qtype, qclass; + unsigned int consumed = 0; + DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false); - const uint16_t * flags = getFlagsFromDNSHeader(dh); - const uint16_t origFlags = *flags; - unsigned int consumed = 0; - DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, data[msgIdx].dest.sin4.sin_family != 0 ? &data[msgIdx].dest : &cs->local, &remote, dh, sizeof(data[msgIdx].packet), len, false); #ifdef HAVE_PROTOBUF - dq.uniqueId = uuidGenerator(); + dq.uniqueId = t_uuidGenerator(); #endif - string poolname; - int delayMsec=0; - /* we need an accurate ("real") value for the response and - to store into the IDS, but not for insertion into the - rings for example */ - struct timespec realTime; - struct timespec now; - gettime(&now); - gettime(&realTime, true); - - if (!processQuery(localDynNMGBlock, localDynSMTBlock, localRulactions, dq, poolname, &delayMsec, now)) - { - continue; - } + string poolname; + int delayMsec = 0; + /* we need an accurate ("real") value for the response and + to store into the IDS, but not for insertion into the + rings for example */ + struct timespec realTime; + struct timespec now; + gettime(&now); + gettime(&realTime, true); + + if (!processQuery(holders, dq, poolname, &delayMsec, now)) + { + return; + } - if(dq.dh->qr) { // something turned it into a response + if(dq.dh->qr) { // something turned it into a response + g_stats.selfAnswered++; + restoreFlags(dh, origFlags); + + if (!cs.muted) { char* response = 
query; uint16_t responseLen = dq.len; - g_stats.selfAnswered++; - - restoreFlags(dh, origFlags); - if (!cs->muted) { #ifdef HAVE_DNSCRYPT - if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { - continue; - } + if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { + return; + } #endif - outMsgVec[msgsToSend].msg_len = 0; - fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, response, responseLen, &data[msgIdx].remote); - - if (data[msgIdx].dest.sin4.sin_family == 0) { - outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr; - } - else { - addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0); - } - - msgsToSend++; - //sendUDPResponse(cs->udpFD, response, responseLen, delayMsec, dest, remote); +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + if (delayMsec == 0 && responsesVect != nullptr) { + queueResponse(cs, response, responseLen, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf); + (*queuedResponses)++; + } + else +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ + { + sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, remote); } - - continue; } - DownstreamState* ss = nullptr; - std::shared_ptr serverPool = getPool(*localPools, poolname); - std::shared_ptr packetCache = nullptr; - auto policy = localPolicy->policy; - if (serverPool->policy != nullptr) { - policy = serverPool->policy->policy; - } - { - std::lock_guard lock(g_luamutex); - ss = policy(serverPool->servers, &dq).get(); - packetCache = serverPool->packetCache; - } + return; + } + + DownstreamState* ss = nullptr; + std::shared_ptr serverPool = getPool(*holders.pools, poolname); + std::shared_ptr packetCache = nullptr; + auto policy = holders.policy->policy; + if (serverPool->policy != nullptr) { + policy = serverPool->policy->policy; + } + { + std::lock_guard 
lock(g_luamutex); + ss = policy(serverPool->servers, &dq).get(); + packetCache = serverPool->packetCache; + } - bool ednsAdded = false; - bool ecsAdded = false; - if (dq.useECS && ss && ss->useECS) { - handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength); + bool ednsAdded = false; + bool ecsAdded = false; + if (dq.useECS && ss && ss->useECS) { + if (!handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength)) { + vinfolog("Dropping query from µs because we couldn't insert the ECS value", remote.toStringWithPort()); + return; } + } - uint32_t cacheKey = 0; - if (packetCache && !dq.skipCache) { -// char cachedResponse[4096]; - uint16_t cachedResponseSize = dq.size; - uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL; - if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, allowExpired)) { - DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, (dnsheader*) query, dq.size, cachedResponseSize, false, &realTime); + uint32_t cacheKey = 0; + if (packetCache && !dq.skipCache) { + uint16_t cachedResponseSize = dq.size; + uint32_t allowExpired = ss ? 
0 : g_staleCacheEntriesTTL; + if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, allowExpired)) { + DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast(query), dq.size, cachedResponseSize, false, &realTime); #ifdef HAVE_PROTOBUF - dr.uniqueId = dq.uniqueId; + dr.uniqueId = dq.uniqueId; #endif - if (!processResponse(localCacheHitRespRulactions, dr, &delayMsec)) { - continue; - } + if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) { + return; + } - if (!cs->muted) { + if (!cs.muted) { #ifdef HAVE_DNSCRYPT - if (!encryptResponse(query, &cachedResponseSize, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { - continue; - } + if (!encryptResponse(query, &cachedResponseSize, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { + return; + } #endif - //sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, delayMsec, dest, remote); - outMsgVec[msgsToSend].msg_len = 0; - fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, query, cachedResponseSize, &data[msgIdx].remote); - if (data[msgIdx].dest.sin4.sin_family == 0) { - outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr; - } - else { - addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0); - } - - msgsToSend++; +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + if (delayMsec == 0 && responsesVect != nullptr) { + queueResponse(cs, query, cachedResponseSize, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf); + (*queuedResponses)++; + } + else +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ + { + sendUDPResponse(cs.udpFD, query, cachedResponseSize, delayMsec, dest, remote); } - - g_stats.cacheHits++; - g_stats.latency0_1++; // we're not going to measure this - doLatencyAverages(0); // same - continue; } - g_stats.cacheMisses++; + + g_stats.cacheHits++; + g_stats.latency0_1++; // we're not 
going to measure this + doLatencyAverages(0); // same + return; } + g_stats.cacheMisses++; + } - if(!ss) { - g_stats.noPolicy++; + if(!ss) { + g_stats.noPolicy++; - if (g_servFailOnNoPolicy) { - char* response = query; - uint16_t responseLen = dq.len; - restoreFlags(dh, origFlags); + if (g_servFailOnNoPolicy && !cs.muted) { + char* response = query; + uint16_t responseLen = dq.len; + restoreFlags(dh, origFlags); - dq.dh->rcode = RCode::ServFail; - dq.dh->qr = true; + dq.dh->rcode = RCode::ServFail; + dq.dh->qr = true; #ifdef HAVE_DNSCRYPT - if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { - continue; - } + if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) { + return; + } #endif - outMsgVec[msgsToSend].msg_len = 0; - fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, response, responseLen, &data[msgIdx].remote); - if (data[msgIdx].dest.sin4.sin_family == 0) { - outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr; - } - else { - addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0); - } - - msgsToSend++; -// sendUDPResponse(cs->udpFD, response, responseLen, 0, dest, remote); +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + if (responsesVect != nullptr) { + queueResponse(cs, response, responseLen, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf); + (*queuedResponses)++; + } + else +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ + { + sendUDPResponse(cs.udpFD, response, responseLen, 0, dest, remote); } - vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "Dropped" : "ServFailed", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort()); - continue; - } + vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? 
"Dropped" : "ServFailed", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort()); + return; + } - ss->queries++; - - unsigned int idOffset = (ss->idOffset++) % ss->idStates.size(); - IDState* ids = &ss->idStates[idOffset]; - ids->age = 0; + ss->queries++; - if(ids->origFD < 0) // if we are reusing, no change in outstanding - ss->outstanding++; - else { - ss->reuseds++; - g_stats.downstreamTimeouts++; - } + unsigned int idOffset = (ss->idOffset++) % ss->idStates.size(); + IDState* ids = &ss->idStates[idOffset]; + ids->age = 0; - ids->cs = cs; - ids->origFD = cs->udpFD; - ids->origID = dh->id; - ids->origRemote = remote; - ids->sentTime.set(realTime); - ids->qname = qname; - ids->qtype = dq.qtype; - ids->qclass = dq.qclass; - ids->delayMsec = delayMsec; - ids->origFlags = origFlags; - ids->cacheKey = cacheKey; - ids->skipCache = dq.skipCache; - ids->packetCache = packetCache; - ids->ednsAdded = ednsAdded; - ids->ecsAdded = ecsAdded; - - /* If we couldn't harvest the real dest addr, still - write down the listening addr since it will be useful - (especially if it's not an 'any' one). - We need to keep track of which one it is since we may - want to use the real but not the listening addr to reply. 
- */ - if (data[msgIdx].dest.sin4.sin_family != 0) { - ids->origDest = data[msgIdx].dest; - ids->destHarvested = true; - } - else { - ids->origDest = cs->local; - ids->destHarvested = false; - } + if(ids->origFD < 0) // if we are reusing, no change in outstanding + ss->outstanding++; + else { + ss->reuseds++; + g_stats.downstreamTimeouts++; + } + + ids->cs = &cs; + ids->origFD = cs.udpFD; + ids->origID = dh->id; + ids->origRemote = remote; + ids->sentTime.set(realTime); + ids->qname = qname; + ids->qtype = dq.qtype; + ids->qclass = dq.qclass; + ids->delayMsec = delayMsec; + ids->origFlags = origFlags; + ids->cacheKey = cacheKey; + ids->skipCache = dq.skipCache; + ids->packetCache = packetCache; + ids->ednsAdded = ednsAdded; + ids->ecsAdded = ecsAdded; + + /* If we couldn't harvest the real dest addr, still + write down the listening addr since it will be useful + (especially if it's not an 'any' one). + We need to keep track of which one it is since we may + want to use the real but not the listening addr to reply. 
+ */ + if (dest.sin4.sin_family != 0) { + ids->origDest = dest; + ids->destHarvested = true; + } + else { + ids->origDest = cs.local; + ids->destHarvested = false; + } #ifdef HAVE_DNSCRYPT - ids->dnsCryptQuery = dnsCryptQuery; + ids->dnsCryptQuery = dnsCryptQuery; #endif #ifdef HAVE_PROTOBUF - ids->uniqueId = dq.uniqueId; + ids->uniqueId = dq.uniqueId; #endif - dh->id = idOffset; + dh->id = idOffset; - if (largerQuery.empty()) { - ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len); - } - else { - ret = udpClientSendRequestToBackend(ss, ss->fd, largerQuery.c_str(), largerQuery.size()); - largerQuery.clear(); - } + ssize_t ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len); - if(ret < 0) { - ss->sendErrors++; - g_stats.downstreamSendErrors++; - } + if(ret < 0) { + ss->sendErrors++; + g_stats.downstreamSendErrors++; + } + + vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName()); + } + catch(const std::exception& e){ + vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); + } +} + +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) +static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders) +{ + struct MMReceiver + { + char packet[4096]; + /* used by HarvestDestinationAddress */ + char cbuf[256]; + ComboAddress remote; + ComboAddress dest; + struct iovec iov; + }; + const size_t vectSize = g_udpVectorSize; + /* the actual buffer is larger because: + - we may have to add EDNS and/or ECS + - we use it for self-generated responses (from rule or cache) + but we only accept incoming payloads up to that size + */ + static_assert(s_udpIncomingBufferSize <= sizeof(MMReceiver::packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)"); + + auto recvData = std::unique_ptr(new MMReceiver[vectSize]); 
+ auto msgVec = std::unique_ptr(new struct mmsghdr[vectSize]); + auto outMsgVec = std::unique_ptr(new struct mmsghdr[vectSize]); + + /* initialize the structures needed to receive our messages */ + for (size_t idx = 0; idx < vectSize; idx++) { + recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family; + fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, s_udpIncomingBufferSize, &recvData[idx].remote); + } + + /* go now */ + for(;;) { - vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName()); + /* reset the IO vector, since it's also used to send the vector of responses + to avoid having to copy the data around */ + for (size_t idx = 0; idx < vectSize; idx++) { + recvData[idx].iov.iov_base = recvData[idx].packet; + recvData[idx].iov.iov_len = sizeof(recvData[idx].packet); } - catch(const std::exception& e){ - vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); + + /* block until we have at least one message ready, but return + as many as possible to save the syscall costs */ + int msgsGot = recvmmsg(cs->udpFD, msgVec.get(), vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr); + + if (msgsGot <= 0) { + vinfolog("Getting UDP messages via recvmmsg() failed with: %s", strerror(errno)); + continue; } + + unsigned int msgsToSend = 0; + + /* process the received messages */ + for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) { + const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr; + unsigned int got = msgVec[msgIdx].msg_len; + const ComboAddress& remote = recvData[msgIdx].remote; + + if (got < sizeof(struct dnsheader)) { + g_stats.nonCompliantQueries++; + continue; + } + + processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast(got), sizeof(recvData[msgIdx].packet), 
outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, recvData[msgIdx].cbuf); + } - if (msgsToSend > 0) { - int sent = sendmmsg(cs->udpFD, outMsgVec, msgsToSend, 0); + + /* immediate (not delayed or sent to a backend) responses (mostly from a rule, dynamic block + or the cache) can be sent in batch too */ + + if (msgsToSend > 0 && msgsToSend <= static_cast(msgsGot)) { + int sent = sendmmsg(cs->udpFD, outMsgVec.get(), msgsToSend, 0); + if (sent < 0 || static_cast(sent) != msgsToSend) { - vinfolog("Error sending responses with sendmmsg (%d on %u): %s", sent, msgsToSend, strerror(errno)); + vinfolog("Error sending responses with sendmmsg() (%d on %u): %s", sent, msgsToSend, strerror(errno)); } - //vinfolog("Sent %d responses", sent); } + } - return 0; +} +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ + +// listens to incoming queries, sends out to downstream servers, noting the intended return path +static void* udpClientThread(ClientState* cs) +try +{ + LocalHolders holders; + +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + if (g_udpVectorSize > 1) { + MultipleMessagesUDPClientThread(cs, holders); + + } + else +#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ + { + char packet[4096]; + /* the actual buffer is larger because: + - we may have to add EDNS and/or ECS + - we use it for self-generated responses (from rule or cache) + but we only accept incoming payloads up to that size + */ + static_assert(s_udpIncomingBufferSize <= sizeof(packet), "the incoming buffer size should not be larger than sizeof(packet)"); + struct msghdr msgh; + struct iovec iov; + /* used by HarvestDestinationAddress */ + char cbuf[256]; + + ComboAddress remote; + ComboAddress dest; + remote.sin4.sin_family = cs->local.sin4.sin_family; + fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), packet, s_udpIncomingBufferSize, &remote); /* receive at most s_udpIncomingBufferSize bytes, the rest of packet[] is headroom */ + + for(;;) { + ssize_t got = recvmsg(cs->udpFD, 
&msgh, 0); + + if (got < 0 || static_cast(got) < sizeof(struct dnsheader)) { + g_stats.nonCompliantQueries++; + continue; + } + + processUDPQuery(*cs, holders, &msgh, remote, dest, packet, static_cast(got), sizeof(packet), nullptr, nullptr, nullptr, nullptr); /* same as the recvmmsg() path: hand over the full size of packet[] */ + } + } + + return nullptr; } catch(const std::exception &e) { errlog("UDP client thread died because of exception: %s", e.what()); - return 0; + return nullptr; } catch(const PDNSException &e) { errlog("UDP client thread died because of PowerDNS exception: %s", e.reason); - return 0; + return nullptr; } catch(...) { errlog("UDP client thread died because of an exception: %s", "unknown"); - return 0; + return nullptr; } - static bool upCheck(DownstreamState& ds) try { @@ -1974,6 +2068,9 @@ try #ifdef HAVE_DNSCRYPT cout<<"dnscrypt "; #endif +#ifdef HAVE_EBPF + cout<<"ebpf "; +#endif #ifdef HAVE_LIBSODIUM cout<<"libsodium "; #endif @@ -1983,6 +2080,12 @@ try #ifdef HAVE_RE2 cout<<"re2 "; #endif +#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) + cout<<"recvmmsg/sendmmsg "; +#endif +#ifdef HAVE_NET_SNMP + cout<<"snmp "; +#endif #ifdef HAVE_SYSTEMD cout<<"systemd"; #endif diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index a8d56fc33..020e43e59 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -54,7 +54,6 @@ extern uint16_t g_ECSSourcePrefixV4; extern uint16_t g_ECSSourcePrefixV6; extern bool g_ECSOverride; - class QTag { public: @@ -124,6 +123,9 @@ private: static constexpr char const *strSep = "\t"; }; +#ifdef HAVE_PROTOBUF +extern thread_local boost::uuids::random_generator t_uuidGenerator; +#endif struct DNSQuestion { @@ -758,6 +760,7 @@ extern bool g_servFailOnNoPolicy; extern uint32_t g_hashperturb; extern bool g_useTCPSinglePipe; extern std::atomic g_downstreamTCPCleanupInterval; +extern size_t g_udpVectorSize; struct ConsoleKeyword { std::string name; @@ -783,6 +786,22 @@ extern shared_ptr g_defaultBPFFilter; extern std::vector > g_dynBPFFilters; #endif /* 
HAVE_EBPF */ +struct LocalHolders +{ + LocalHolders(): acl(g_ACL.getLocal()), policy(g_policy.getLocal()), rulactions(g_rulactions.getLocal()), cacheHitRespRulactions(g_cachehitresprulactions.getLocal()), servers(g_dstates.getLocal()), dynNMGBlock(g_dynblockNMG.getLocal()), dynSMTBlock(g_dynblockSMT.getLocal()), pools(g_pools.getLocal()) + { + } + + LocalStateHolder acl; + LocalStateHolder policy; + LocalStateHolder, std::shared_ptr > > > rulactions; + LocalStateHolder, std::shared_ptr > > > cacheHitRespRulactions; + LocalStateHolder servers; + LocalStateHolder > dynNMGBlock; + LocalStateHolder > dynSMTBlock; + LocalStateHolder pools; +}; + struct dnsheader; void controlThread(int fd, ComboAddress local); @@ -817,11 +836,11 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n void resetLuaSideEffect(); // reset to indeterminate state bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote); -bool processQuery(LocalStateHolder >& localDynBlockNMG, - LocalStateHolder >& localDynBlockSMT, LocalStateHolder, std::shared_ptr > > >& localRulactions, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now); +bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now); bool processResponse(LocalStateHolder, std::shared_ptr > > >& localRespRulactions, DNSResponse& dr, int* delayMsec); bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector& rewrittenResponse, uint16_t addRoom); void restoreFlags(struct dnsheader* dh, uint16_t origFlags); +bool checkQueryHeaders(const struct dnsheader* dh); #ifdef HAVE_DNSCRYPT extern std::vector>> g_dnsCryptLocals; diff --git a/pdns/dnsdistdist/dnsrulactions.cc b/pdns/dnsdistdist/dnsrulactions.cc index 
49292763e..e5a54c80d 100644 --- a/pdns/dnsdistdist/dnsrulactions.cc +++ b/pdns/dnsdistdist/dnsrulactions.cc @@ -49,21 +49,17 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) con if(d_addECS) { std::string query; - std::string larger; uint16_t len = dq->len; bool ednsAdded = false; bool ecsAdded = false; query.reserve(dq->size); query.assign((char*) dq->dh, len); - handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, larger, &ednsAdded, &ecsAdded, *dq->remote, dq->ecsOverride, dq->ecsPrefixLength); - - if (larger.empty()) { - res = send(d_fd, query.c_str(), len, 0); - } - else { - res = send(d_fd, larger.c_str(), larger.length(), 0); + if (!handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, *dq->remote, dq->ecsOverride, dq->ecsPrefixLength)) { + return DNSAction::Action::None; } + + res = send(d_fd, query.c_str(), len, 0); } else { res = send(d_fd, (char*)dq->dh, dq->len, 0); diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 36287d04a..7c9114f68 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -44,7 +44,7 @@ bool g_verbose{true}; static const uint16_t ECSSourcePrefixV4 = 24; static const uint16_t ECSSourcePrefixV6 = 56; -static void validateQuery(const char * packet, size_t packetSize) +static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true) { MOADNSParser mdp(true, packet, packetSize); @@ -53,7 +53,7 @@ static void validateQuery(const char * packet, size_t packetSize) BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1); BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0); BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0); - BOOST_CHECK_EQUAL(mdp.d_header.arcount, 1); + BOOST_CHECK_EQUAL(mdp.d_header.arcount, (hasEdns ? 
1 : 0)); } static void validateResponse(const char * packet, size_t packetSize, bool hasEdns, uint8_t additionalCount=0) @@ -71,7 +71,6 @@ static void validateResponse(const char * packet, size_t packetSize, bool hasEdn BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) { - string largerPacket; bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote; @@ -92,30 +91,29 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK((size_t) len > query.size()); - BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, false); validateQuery(packet, len); /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; consumed = 0; len = query.size(); - qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? 
ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK_EQUAL((size_t) len, query.size()); - BOOST_CHECK(largerPacket.size() > query.size()); - BOOST_CHECK_EQUAL(ednsAdded, true); + BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(largerPacket.c_str(), largerPacket.size()); + validateQuery(reinterpret_cast(query.data()), len, false); } BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { - string largerPacket; bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote; @@ -138,30 +136,29 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK((size_t) len > query.size()); - BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet, len); /* not large enough packet */ consumed = 0; + ednsAdded = false; + ecsAdded = false; len = query.size(); - qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? 
ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK_EQUAL((size_t) len, query.size()); - BOOST_CHECK(largerPacket.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); - BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(largerPacket.c_str(), largerPacket.size()); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(reinterpret_cast(query.data()), len); } BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { - string largerPacket; bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote("192.168.1.25"); @@ -190,16 +187,14 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK_EQUAL((size_t) len, query.size()); - BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); validateQuery(packet, len); } BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { - string largerPacket; bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote("192.168.1.25"); @@ -228,16 +223,14 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? 
ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK((size_t) len < query.size()); - BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); validateQuery(packet, len); } BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { - string largerPacket; bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote("192.168.1.25"); @@ -266,26 +259,26 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK((size_t) len > query.size()); - BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); validateQuery(packet, len); /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; consumed = 0; len = query.size(); - qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? 
ECSSourcePrefixV4 : ECSSourcePrefixV6)); BOOST_CHECK_EQUAL((size_t) len, query.size()); - BOOST_CHECK(largerPacket.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(largerPacket.c_str(), largerPacket.size()); + validateQuery(reinterpret_cast(query.data()), len); } BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) {