bool outstanding = false;
time_t lastTCPCleanup = time(nullptr);
-
- auto localPolicy = g_policy.getLocal();
- auto localRulactions = g_rulactions.getLocal();
+ LocalHolders holders;
auto localRespRulactions = g_resprulactions.getLocal();
- auto localCacheHitRespRulactions = g_cachehitresprulactions.getLocal();
- auto localDynBlockNMG = g_dynblockNMG.getLocal();
- auto localDynBlockSMT = g_dynblockSMT.getLocal();
- auto localPools = g_pools.getLocal();
-#ifdef HAVE_PROTOBUF
- boost::uuids::random_generator uuidGenerator;
-#endif
#ifdef HAVE_DNSCRYPT
/* when the answer is encrypted in place, we need to get a copy
of the original header before encryption to fill the ring buffer */
delete citmp;
uint16_t qlen, rlen;
- string largerQuery;
vector<uint8_t> rewrittenResponse;
shared_ptr<DownstreamState> ds;
ComboAddress dest;
socklen_t len = dest.getSocklen();
size_t queriesCount = 0;
time_t connectionStartTime = time(NULL);
+ std::vector<char> queryBuffer;
if (!setNonBlocking(ci.fd))
goto drop;
if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
break;
+ queriesCount++;
+
+ if (qlen < sizeof(dnsheader)) {
+ g_stats.nonCompliantQueries++;
+ break;
+ }
+
ci.cs->queries++;
g_stats.queries++;
- queriesCount++;
-
if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
break;
break;
}
- if (qlen < sizeof(dnsheader)) {
- g_stats.nonCompliantQueries++;
- break;
- }
-
bool ednsAdded = false;
bool ecsAdded = false;
- /* if the query is small, allocate a bit more
- memory to be able to spoof the content,
+ /* allocate a bit more memory to be able to spoof the content,
or to add ECS without allocating a new buffer */
- size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
- char queryBuffer[querySize];
- const char* query = queryBuffer;
- readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime);
+ queryBuffer.reserve(qlen + 512);
+
+ char* query = &queryBuffer[0];
+ readn2WithTimeout(ci.fd, query, qlen, g_tcpRecvTimeout, remainingTime);
#ifdef HAVE_DNSCRYPT
- std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;
if (ci.cs->dnscryptCtx) {
dnsCryptQuery = std::make_shared<DnsCryptQuery>();
uint16_t decryptedQueryLen = 0;
vector<uint8_t> response;
- bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
+ bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, query, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
if (!decrypted) {
if (response.size() > 0) {
qlen = decryptedQueryLen;
}
#endif
- struct dnsheader* dh = (struct dnsheader*) query;
-
- if(dh->qr) { // don't respond to responses
- g_stats.nonCompliantQueries++;
- goto drop;
- }
+ struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
- if(dh->qdcount == 0) {
- g_stats.emptyQueries++;
+ if (!checkQueryHeaders(dh)) {
goto drop;
}
- if (dh->rd) {
- g_stats.rdQueries++;
- }
-
const uint16_t* flags = getFlagsFromDNSHeader(dh);
uint16_t origFlags = *flags;
uint16_t qtype, qclass;
unsigned int consumed = 0;
DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true);
+ DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, dh, queryBuffer.capacity(), qlen, true);
#ifdef HAVE_PROTOBUF
- dq.uniqueId = uuidGenerator();
+ dq.uniqueId = t_uuidGenerator();
#endif
string poolname;
gettime(&now);
gettime(&queryRealTime, true);
- if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, dq, poolname, &delayMsec, now)) {
+ if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
goto drop;
}
if(dq.dh->qr) { // something turned it into a response
restoreFlags(dh, origFlags);
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
+ if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
goto drop;
}
#endif
goto drop;
}
- std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
+ std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
- auto policy = localPolicy->policy;
+ auto policy = holders.policy->policy;
if (serverPool->policy != nullptr) {
policy = serverPool->policy->policy;
}
if (dq.useECS && ds && ds->useECS) {
uint16_t newLen = dq.len;
- handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
- if (largerQuery.empty() == false) {
- query = largerQuery.c_str();
- dq.len = (uint16_t) largerQuery.size();
- dq.size = largerQuery.size();
- } else {
- dq.len = newLen;
+ if (!handleEDNSClientSubnet(query, dq.size, consumed, &newLen, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength)) {
+ vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
+ goto drop;
}
+ dq.len = newLen;
}
uint32_t cacheKey = 0;
#ifdef HAVE_PROTOBUF
dr.uniqueId = dq.uniqueId;
#endif
- if (!processResponse(localCacheHitRespRulactions, dr, &delayMsec)) {
+ if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
goto drop;
}
dq.dh->qr = true;
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
+ if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
goto drop;
}
#endif
g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
}
- largerQuery.clear();
rewrittenResponse.clear();
}
}
#include <systemd/sd-daemon.h>
#endif
+#ifdef HAVE_PROTOBUF
+thread_local boost::uuids::random_generator t_uuidGenerator;
+#endif
+
/* Known sins:
Receiver is currently single threaded
#endif /* HAVE_EBPF */
vector<ClientState *> g_frontends;
GlobalStateHolder<pools_t> g_pools;
+size_t g_udpVectorSize{1};
bool g_snmpEnabled{false};
bool g_snmpTrapsEnabled{false};
bool g_servFailOnNoPolicy{false};
bool g_truncateTC{false};
bool g_fixupCase{0};
+
+static const size_t s_udpIncomingBufferSize{1500};
+
static void truncateTC(const char* packet, uint16_t* len)
try
{
if (ids->origFD == origFD) {
#ifdef HAVE_DNSCRYPT
- ids->dnsCryptQuery = 0;
+ ids->dnsCryptQuery = nullptr;
#endif
ids->origFD = -1;
}
}
}
-bool processQuery(LocalStateHolder<NetmaskTree<DynBlock> >& localDynNMGBlock,
- LocalStateHolder<SuffixMatchTree<DynBlock> >& localDynSMTBlock,
- LocalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > >& localRulactions, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
+bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
{
{
WriteLock wl(&g_rings.queryLock);
}
}
- if(auto got=localDynNMGBlock->lookup(*dq.remote)) {
+ if(auto got = holders.dynNMGBlock->lookup(*dq.remote)) {
auto updateBlockStats = [&got]() {
g_stats.dynBlocked++;
got->second.blocks++;
}
}
- if(auto got=localDynSMTBlock->lookup(*dq.qname)) {
+ if(auto got = holders.dynSMTBlock->lookup(*dq.qname)) {
auto updateBlockStats = [&got]() {
g_stats.dynBlocked++;
got->blocks++;
DNSAction::Action action=DNSAction::Action::None;
string ruleresult;
- for(const auto& lr : *localRulactions) {
+ for(const auto& lr : *holders.rulactions) {
if(lr.first->matches(&dq)) {
lr.first->d_matches++;
action=(*lr.second)(&dq, &ruleresult);
return result;
}
-// listens to incoming queries, sends out to downstream servers, noting the intended return path
-static void* udpClientThread(ClientState* cs)
-try
+static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest)
{
- string largerQuery;
- uint16_t qtype, qclass;
- uint16_t queryId;
-#ifdef HAVE_PROTOBUF
- boost::uuids::random_generator uuidGenerator;
-#endif
+ if (msgh->msg_flags & MSG_TRUNC) {
+ /* message was too large for our buffer */
+ vinfolog("Dropping message too large for our buffer");
+ g_stats.nonCompliantQueries++;
+ return false;
+ }
- auto acl = g_ACL.getLocal();
- auto localPolicy = g_policy.getLocal();
- auto localRulactions = g_rulactions.getLocal();
- auto localCacheHitRespRulactions = g_cachehitresprulactions.getLocal();
- auto localServers = g_dstates.getLocal();
- auto localDynNMGBlock = g_dynblockNMG.getLocal();
- auto localDynSMTBlock = g_dynblockSMT.getLocal();
- auto localPools = g_pools.getLocal();
-
- static const size_t vectSize = 50;
- struct
- {
- char packet[4096];
- /* used by HarvestDestinationAddress */
- char cbuf[256];
- ComboAddress remote;
- ComboAddress dest;
- struct iovec iov;
+ if(!holders.acl->match(remote)) {
+ vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
+ g_stats.aclDrops++;
+ return false;
}
- data[vectSize];
- struct mmsghdr msgVec[vectSize];
- struct mmsghdr outMsgVec[vectSize];
- for (size_t idx = 0; idx < vectSize; idx++) {
- data[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
+ cs.queries++;
+ g_stats.queries++;
- fillMSGHdr(&msgVec[idx].msg_hdr, &data[idx].iov, data[idx].cbuf, sizeof(data[idx].cbuf), data[idx].packet, sizeof(data[idx].packet), &data[idx].remote);
+ if (HarvestDestinationAddress(msgh, &dest)) {
+ /* we don't get the port, only the address */
+ dest.sin4.sin_port = cs.local.sin4.sin_port;
+ }
+ else {
+ dest.sin4.sin_family = 0;
}
- for(;;) {
+ return true;
+}
+
#ifdef HAVE_DNSCRYPT
- std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
-#endif
- for (size_t idx = 0; idx < vectSize; idx++) {
- data[idx].iov.iov_base = data[idx].packet;
- data[idx].iov.iov_len = sizeof(data[idx].packet);
- }
- int msgsGot = recvmmsg(cs->udpFD, msgVec, vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr);
+static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DnsCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote)
+{
+ if (cs.dnscryptCtx) {
+ vector<uint8_t> response;
+ uint16_t decryptedQueryLen = 0;
- if (msgsGot <= 0) {
- vinfolog("recvmmsg() failed with: %s", strerror(errno));
- continue;
+ dnsCryptQuery = std::make_shared<DnsCryptQuery>();
+
+ bool decrypted = handleDnsCryptQuery(cs.dnscryptCtx, const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, response);
+
+ if (!decrypted) {
+ if (response.size() > 0) {
+ sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(response.data()), static_cast<uint16_t>(response.size()), 0, dest, remote);
}
- //vinfolog("Got %d messages", msgsGot);
- unsigned int msgsToSend = 0;
+ return false;
+ }
- for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) {
- const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr;
- unsigned int ret = msgVec[msgIdx].msg_len;
- const ComboAddress& remote = data[msgIdx].remote;
+ len = decryptedQueryLen;
+ }
+ return true;
+}
+#endif /* HAVE_DNSCRYPT */
- try {
- char* query = data[msgIdx].packet;
+bool checkQueryHeaders(const struct dnsheader* dh)
+{
+ if (dh->qr) { // don't respond to responses
+ g_stats.nonCompliantQueries++;
+ return false;
+ }
- queryId = 0;
+ if (dh->qdcount == 0) {
+ g_stats.emptyQueries++;
+ return false;
+ }
- if(!acl->match(remote)) {
- vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
- g_stats.aclDrops++;
- continue;
- }
+ if (dh->rd) {
+ g_stats.rdQueries++;
+ }
- cs->queries++;
- g_stats.queries++;
+ return true;
+}
- if(ret < (int)sizeof(struct dnsheader)) {
- g_stats.nonCompliantQueries++;
- continue;
- }
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, char* cbuf)
+{
+ outMsg.msg_len = 0;
+ fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast<char*>(response), responseLen, const_cast<ComboAddress*>(&remote));
- if (msgh->msg_flags & MSG_TRUNC) {
- /* message was too large for our buffer */
- vinfolog("Dropping message too large for our buffer");
- g_stats.nonCompliantQueries++;
- continue;
- }
+ if (dest.sin4.sin_family == 0) {
+ outMsg.msg_hdr.msg_control = nullptr;
+ }
+ else {
+ addCMsgSrcAddr(&outMsg.msg_hdr, cbuf, &dest, 0);
+ }
+}
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
- uint16_t len = (uint16_t) ret;
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+{
+ assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
+ uint16_t queryId = 0;
- if (HarvestDestinationAddress(msgh, &data[msgIdx].dest)) {
- /* we don't get the port, only the address */
- data[msgIdx].dest.sin4.sin_port = cs->local.sin4.sin_port;
- }
- else {
- data[msgIdx].dest.sin4.sin_family = 0;
- }
+ try {
+ if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
+ return;
+ }
#ifdef HAVE_DNSCRYPT
- if (cs->dnscryptCtx) {
- vector<uint8_t> response;
- uint16_t decryptedQueryLen = 0;
- dnsCryptQuery = std::make_shared<DnsCryptQuery>();
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;
- bool decrypted = handleDnsCryptQuery(cs->dnscryptCtx, query, len, dnsCryptQuery, &decryptedQueryLen, false, response);
-
- if (!decrypted) {
- if (response.size() > 0) {
- sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), (uint16_t) response.size(), 0, data[msgIdx].dest, remote);
- }
- continue;
- }
- len = decryptedQueryLen;
- }
+ if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote)) {
+ return;
+ }
#endif
- struct dnsheader* dh = (struct dnsheader*) query;
- queryId = ntohs(dh->id);
+ struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+ queryId = ntohs(dh->id);
- if(dh->qr) { // don't respond to responses
- g_stats.nonCompliantQueries++;
- continue;
- }
-
- if(dh->qdcount == 0) {
- g_stats.emptyQueries++;
- continue;
- }
+ if (!checkQueryHeaders(dh)) {
+ return;
+ }
- if (dh->rd) {
- g_stats.rdQueries++;
- }
+ const uint16_t * flags = getFlagsFromDNSHeader(dh);
+ const uint16_t origFlags = *flags;
+ uint16_t qtype, qclass;
+ unsigned int consumed = 0;
+ DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+ DNSQuestion dq(&qname, qtype, qclass, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false);
- const uint16_t * flags = getFlagsFromDNSHeader(dh);
- const uint16_t origFlags = *flags;
- unsigned int consumed = 0;
- DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, data[msgIdx].dest.sin4.sin_family != 0 ? &data[msgIdx].dest : &cs->local, &remote, dh, sizeof(data[msgIdx].packet), len, false);
#ifdef HAVE_PROTOBUF
- dq.uniqueId = uuidGenerator();
+ dq.uniqueId = t_uuidGenerator();
#endif
- string poolname;
- int delayMsec=0;
- /* we need an accurate ("real") value for the response and
- to store into the IDS, but not for insertion into the
- rings for example */
- struct timespec realTime;
- struct timespec now;
- gettime(&now);
- gettime(&realTime, true);
-
- if (!processQuery(localDynNMGBlock, localDynSMTBlock, localRulactions, dq, poolname, &delayMsec, now))
- {
- continue;
- }
+ string poolname;
+ int delayMsec = 0;
+ /* we need an accurate ("real") value for the response and
+ to store into the IDS, but not for insertion into the
+ rings for example */
+ struct timespec realTime;
+ struct timespec now;
+ gettime(&now);
+ gettime(&realTime, true);
+
+ if (!processQuery(holders, dq, poolname, &delayMsec, now))
+ {
+ return;
+ }
- if(dq.dh->qr) { // something turned it into a response
+ if(dq.dh->qr) { // something turned it into a response
+ g_stats.selfAnswered++;
+ restoreFlags(dh, origFlags);
+
+ if (!cs.muted) {
char* response = query;
uint16_t responseLen = dq.len;
- g_stats.selfAnswered++;
-
- restoreFlags(dh, origFlags);
- if (!cs->muted) {
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
- continue;
- }
+ if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
+ return;
+ }
#endif
- outMsgVec[msgsToSend].msg_len = 0;
- fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, response, responseLen, &data[msgIdx].remote);
-
- if (data[msgIdx].dest.sin4.sin_family == 0) {
- outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr;
- }
- else {
- addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0);
- }
-
- msgsToSend++;
- //sendUDPResponse(cs->udpFD, response, responseLen, delayMsec, dest, remote);
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ if (delayMsec == 0 && responsesVect != nullptr) {
+ queueResponse(cs, response, responseLen, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+ (*queuedResponses)++;
+ }
+ else
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+ {
+ sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, remote);
}
-
- continue;
}
- DownstreamState* ss = nullptr;
- std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
- std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
- auto policy = localPolicy->policy;
- if (serverPool->policy != nullptr) {
- policy = serverPool->policy->policy;
- }
- {
- std::lock_guard<std::mutex> lock(g_luamutex);
- ss = policy(serverPool->servers, &dq).get();
- packetCache = serverPool->packetCache;
- }
+ return;
+ }
+
+ DownstreamState* ss = nullptr;
+ std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
+ std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
+ auto policy = holders.policy->policy;
+ if (serverPool->policy != nullptr) {
+ policy = serverPool->policy->policy;
+ }
+ {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ ss = policy(serverPool->servers, &dq).get();
+ packetCache = serverPool->packetCache;
+ }
- bool ednsAdded = false;
- bool ecsAdded = false;
- if (dq.useECS && ss && ss->useECS) {
- handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength);
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ if (dq.useECS && ss && ss->useECS) {
+ if (!handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength)) {
+ vinfolog("Dropping query from µs because we couldn't insert the ECS value", remote.toStringWithPort());
+ return;
}
+ }
- uint32_t cacheKey = 0;
- if (packetCache && !dq.skipCache) {
-// char cachedResponse[4096];
- uint16_t cachedResponseSize = dq.size;
- uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
- if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, allowExpired)) {
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, (dnsheader*) query, dq.size, cachedResponseSize, false, &realTime);
+ uint32_t cacheKey = 0;
+ if (packetCache && !dq.skipCache) {
+ uint16_t cachedResponseSize = dq.size;
+ uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
+ if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, allowExpired)) {
+ DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, cachedResponseSize, false, &realTime);
#ifdef HAVE_PROTOBUF
- dr.uniqueId = dq.uniqueId;
+ dr.uniqueId = dq.uniqueId;
#endif
- if (!processResponse(localCacheHitRespRulactions, dr, &delayMsec)) {
- continue;
- }
+ if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
+ return;
+ }
- if (!cs->muted) {
+ if (!cs.muted) {
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(query, &cachedResponseSize, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
- continue;
- }
+ if (!encryptResponse(query, &cachedResponseSize, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
+ return;
+ }
#endif
- //sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, delayMsec, dest, remote);
- outMsgVec[msgsToSend].msg_len = 0;
- fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, query, cachedResponseSize, &data[msgIdx].remote);
- if (data[msgIdx].dest.sin4.sin_family == 0) {
- outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr;
- }
- else {
- addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0);
- }
-
- msgsToSend++;
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ if (delayMsec == 0 && responsesVect != nullptr) {
+ queueResponse(cs, query, cachedResponseSize, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+ (*queuedResponses)++;
+ }
+ else
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+ {
+ sendUDPResponse(cs.udpFD, query, cachedResponseSize, delayMsec, dest, remote);
}
-
- g_stats.cacheHits++;
- g_stats.latency0_1++; // we're not going to measure this
- doLatencyAverages(0); // same
- continue;
}
- g_stats.cacheMisses++;
+
+ g_stats.cacheHits++;
+ g_stats.latency0_1++; // we're not going to measure this
+ doLatencyAverages(0); // same
+ return;
}
+ g_stats.cacheMisses++;
+ }
- if(!ss) {
- g_stats.noPolicy++;
+ if(!ss) {
+ g_stats.noPolicy++;
- if (g_servFailOnNoPolicy) {
- char* response = query;
- uint16_t responseLen = dq.len;
- restoreFlags(dh, origFlags);
+ if (g_servFailOnNoPolicy && !cs.muted) {
+ char* response = query;
+ uint16_t responseLen = dq.len;
+ restoreFlags(dh, origFlags);
- dq.dh->rcode = RCode::ServFail;
- dq.dh->qr = true;
+ dq.dh->rcode = RCode::ServFail;
+ dq.dh->qr = true;
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
- continue;
- }
+ if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
+ return;
+ }
#endif
- outMsgVec[msgsToSend].msg_len = 0;
- fillMSGHdr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].iov, nullptr, 0, response, responseLen, &data[msgIdx].remote);
- if (data[msgIdx].dest.sin4.sin_family == 0) {
- outMsgVec[msgsToSend].msg_hdr.msg_control = nullptr;
- }
- else {
- addCMsgSrcAddr(&outMsgVec[msgsToSend].msg_hdr, &data[msgIdx].cbuf, &data[msgIdx].dest, 0);
- }
-
- msgsToSend++;
-// sendUDPResponse(cs->udpFD, response, responseLen, 0, dest, remote);
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ if (responsesVect != nullptr) {
+ queueResponse(cs, response, responseLen, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+ (*queuedResponses)++;
+ }
+ else
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+ {
+ sendUDPResponse(cs.udpFD, response, responseLen, 0, dest, remote);
}
- vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "Dropped" : "ServFailed", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
- continue;
-
}
+ vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "Dropped" : "ServFailed", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
+ return;
+ }
- ss->queries++;
-
- unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
- IDState* ids = &ss->idStates[idOffset];
- ids->age = 0;
+ ss->queries++;
- if(ids->origFD < 0) // if we are reusing, no change in outstanding
- ss->outstanding++;
- else {
- ss->reuseds++;
- g_stats.downstreamTimeouts++;
- }
+ unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
+ IDState* ids = &ss->idStates[idOffset];
+ ids->age = 0;
- ids->cs = cs;
- ids->origFD = cs->udpFD;
- ids->origID = dh->id;
- ids->origRemote = remote;
- ids->sentTime.set(realTime);
- ids->qname = qname;
- ids->qtype = dq.qtype;
- ids->qclass = dq.qclass;
- ids->delayMsec = delayMsec;
- ids->origFlags = origFlags;
- ids->cacheKey = cacheKey;
- ids->skipCache = dq.skipCache;
- ids->packetCache = packetCache;
- ids->ednsAdded = ednsAdded;
- ids->ecsAdded = ecsAdded;
-
- /* If we couldn't harvest the real dest addr, still
- write down the listening addr since it will be useful
- (especially if it's not an 'any' one).
- We need to keep track of which one it is since we may
- want to use the real but not the listening addr to reply.
- */
- if (data[msgIdx].dest.sin4.sin_family != 0) {
- ids->origDest = data[msgIdx].dest;
- ids->destHarvested = true;
- }
- else {
- ids->origDest = cs->local;
- ids->destHarvested = false;
- }
+ if(ids->origFD < 0) // if we are reusing, no change in outstanding
+ ss->outstanding++;
+ else {
+ ss->reuseds++;
+ g_stats.downstreamTimeouts++;
+ }
+
+ ids->cs = &cs;
+ ids->origFD = cs.udpFD;
+ ids->origID = dh->id;
+ ids->origRemote = remote;
+ ids->sentTime.set(realTime);
+ ids->qname = qname;
+ ids->qtype = dq.qtype;
+ ids->qclass = dq.qclass;
+ ids->delayMsec = delayMsec;
+ ids->origFlags = origFlags;
+ ids->cacheKey = cacheKey;
+ ids->skipCache = dq.skipCache;
+ ids->packetCache = packetCache;
+ ids->ednsAdded = ednsAdded;
+ ids->ecsAdded = ecsAdded;
+
+ /* If we couldn't harvest the real dest addr, still
+ write down the listening addr since it will be useful
+ (especially if it's not an 'any' one).
+ We need to keep track of which one it is since we may
+ want to use the real but not the listening addr to reply.
+ */
+ if (dest.sin4.sin_family != 0) {
+ ids->origDest = dest;
+ ids->destHarvested = true;
+ }
+ else {
+ ids->origDest = cs.local;
+ ids->destHarvested = false;
+ }
#ifdef HAVE_DNSCRYPT
- ids->dnsCryptQuery = dnsCryptQuery;
+ ids->dnsCryptQuery = dnsCryptQuery;
#endif
#ifdef HAVE_PROTOBUF
- ids->uniqueId = dq.uniqueId;
+ ids->uniqueId = dq.uniqueId;
#endif
- dh->id = idOffset;
+ dh->id = idOffset;
- if (largerQuery.empty()) {
- ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len);
- }
- else {
- ret = udpClientSendRequestToBackend(ss, ss->fd, largerQuery.c_str(), largerQuery.size());
- largerQuery.clear();
- }
+ ssize_t ret = udpClientSendRequestToBackend(ss, ss->fd, query, dq.len);
- if(ret < 0) {
- ss->sendErrors++;
- g_stats.downstreamSendErrors++;
- }
+ if(ret < 0) {
+ ss->sendErrors++;
+ g_stats.downstreamSendErrors++;
+ }
+
+ vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
+ }
+ catch(const std::exception& e){
+ vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+ }
+}
+
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders)
+{
+ struct MMReceiver
+ {
+ char packet[4096];
+ /* used by HarvestDestinationAddress */
+ char cbuf[256];
+ ComboAddress remote;
+ ComboAddress dest;
+ struct iovec iov;
+ };
+ const size_t vectSize = g_udpVectorSize;
+ /* the actual buffer is larger because:
+ - we may have to add EDNS and/or ECS
+ - we use it for self-generated responses (from rule or cache)
+ but we only accept incoming payloads up to that size
+ */
+ static_assert(s_udpIncomingBufferSize <= sizeof(MMReceiver::packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)");
+
+ auto recvData = std::unique_ptr<MMReceiver[]>(new MMReceiver[vectSize]);
+ auto msgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
+ auto outMsgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
+
+ /* initialize the structures needed to receive our messages */
+ for (size_t idx = 0; idx < vectSize; idx++) {
+ recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
+ fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, s_udpIncomingBufferSize, &recvData[idx].remote);
+ }
+
+ /* go now */
+ for(;;) {
- vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
+ /* reset the IO vector, since it's also used to send the vector of responses
+ to avoid having to copy the data around */
+ for (size_t idx = 0; idx < vectSize; idx++) {
+ recvData[idx].iov.iov_base = recvData[idx].packet;
+ recvData[idx].iov.iov_len = sizeof(recvData[idx].packet);
}
- catch(const std::exception& e){
- vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+
+ /* block until we have at least one message ready, but return
+ as many as possible to save the syscall costs */
+ int msgsGot = recvmmsg(cs->udpFD, msgVec.get(), vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr);
+
+ if (msgsGot <= 0) {
+ vinfolog("Getting UDP messages via recvmmsg() failed with: %s", strerror(errno));
+ continue;
}
+
+ unsigned int msgsToSend = 0;
+
+ /* process the received messages */
+ for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) {
+ const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr;
+ unsigned int got = msgVec[msgIdx].msg_len;
+ const ComboAddress& remote = recvData[msgIdx].remote;
+
+ if (got < sizeof(struct dnsheader)) {
+ g_stats.nonCompliantQueries++;
+ continue;
+ }
+
+ processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast<uint16_t>(got), sizeof(recvData[msgIdx].packet), outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, recvData[msgIdx].cbuf);
+
}
- if (msgsToSend > 0) {
- int sent = sendmmsg(cs->udpFD, outMsgVec, msgsToSend, 0);
+
+ /* immediate (not delayed or sent to a backend) responses (mostly from a rule, dynamic block
+ or the cache) can be sent in batch too */
+
+ if (msgsToSend > 0 && msgsToSend <= static_cast<unsigned int>(msgsGot)) {
+ int sent = sendmmsg(cs->udpFD, outMsgVec.get(), msgsToSend, 0);
+
if (sent < 0 || static_cast<unsigned int>(sent) != msgsToSend) {
- vinfolog("Error sending responses with sendmmsg (%d on %u): %s", sent, msgsToSend, strerror(errno));
+ vinfolog("Error sending responses with sendmmsg() (%d on %u): %s", sent, msgsToSend, strerror(errno));
}
- //vinfolog("Sent %d responses", sent);
}
+
}
- return 0;
+}
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+
+// listens to incoming queries, sends out to downstream servers, noting the intended return path
+static void* udpClientThread(ClientState* cs)
+try
+{
+ LocalHolders holders;
+
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ if (g_udpVectorSize > 1) {
+ MultipleMessagesUDPClientThread(cs, holders);
+
+ }
+ else
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+ {
+ char packet[4096];
+ /* the actual buffer is larger because:
+ - we may have to add EDNS and/or ECS
+ - we use it for self-generated responses (from rule or cache)
+ but we only accept incoming payloads up to that size
+ */
+ static_assert(s_udpIncomingBufferSize <= sizeof(packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)");
+ struct msghdr msgh;
+ struct iovec iov;
+ /* used by HarvestDestinationAddress */
+ char cbuf[256];
+
+ ComboAddress remote;
+ ComboAddress dest;
+ remote.sin4.sin_family = cs->local.sin4.sin_family;
+ fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), packet, sizeof(packet), &remote);
+
+ for(;;) {
+ ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
+
+ if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
+ g_stats.nonCompliantQueries++;
+ continue;
+ }
+
+ processUDPQuery(*cs, holders, &msgh, remote, dest, packet, static_cast<uint16_t>(got), s_udpIncomingBufferSize, nullptr, nullptr, nullptr, nullptr);
+ }
+ }
+
+ return nullptr;
}
catch(const std::exception &e)
{
errlog("UDP client thread died because of exception: %s", e.what());
- return 0;
+ return nullptr;
}
catch(const PDNSException &e)
{
errlog("UDP client thread died because of PowerDNS exception: %s", e.reason);
- return 0;
+ return nullptr;
}
catch(...)
{
errlog("UDP client thread died because of an exception: %s", "unknown");
- return 0;
+ return nullptr;
}
-
static bool upCheck(DownstreamState& ds)
try
{
#ifdef HAVE_DNSCRYPT
cout<<"dnscrypt ";
#endif
+#ifdef HAVE_EBPF
+ cout<<"ebpf ";
+#endif
#ifdef HAVE_LIBSODIUM
cout<<"libsodium ";
#endif
#ifdef HAVE_RE2
cout<<"re2 ";
#endif
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ cout<<"recvmmsg/sendmmsg ";
+#endif
+#ifdef HAVE_NET_SNMP
+ cout<<"snmp ";
+#endif
#ifdef HAVE_SYSTEMD
cout<<"systemd";
#endif
static const uint16_t ECSSourcePrefixV4 = 24;
static const uint16_t ECSSourcePrefixV6 = 56;
-static void validateQuery(const char * packet, size_t packetSize)
+static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true)
{
MOADNSParser mdp(true, packet, packetSize);
BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1);
BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0);
BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0);
- BOOST_CHECK_EQUAL(mdp.d_header.arcount, 1);
+ BOOST_CHECK_EQUAL(mdp.d_header.arcount, (hasEdns ? 1 : 0));
}
static void validateResponse(const char * packet, size_t packetSize, bool hasEdns, uint8_t additionalCount=0)
BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
{
- string largerPacket;
bool ednsAdded = false;
bool ecsAdded = false;
ComboAddress remote;
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK((size_t) len > query.size());
- BOOST_CHECK_EQUAL(largerPacket.size(), 0);
BOOST_CHECK_EQUAL(ednsAdded, true);
BOOST_CHECK_EQUAL(ecsAdded, false);
validateQuery(packet, len);
/* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
consumed = 0;
len = query.size();
- qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK_EQUAL((size_t) len, query.size());
- BOOST_CHECK(largerPacket.size() > query.size());
- BOOST_CHECK_EQUAL(ednsAdded, true);
+ BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
- validateQuery(largerPacket.c_str(), largerPacket.size());
+ validateQuery(reinterpret_cast<char*>(query.data()), len, false);
}
BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
- string largerPacket;
bool ednsAdded = false;
bool ecsAdded = false;
ComboAddress remote;
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK((size_t) len > query.size());
- BOOST_CHECK_EQUAL(largerPacket.size(), 0);
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, true);
validateQuery(packet, len);
/* not large enough packet */
consumed = 0;
+ ednsAdded = false;
+ ecsAdded = false;
len = query.size();
- qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK_EQUAL((size_t) len, query.size());
- BOOST_CHECK(largerPacket.size() > query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
- BOOST_CHECK_EQUAL(ecsAdded, true);
- validateQuery(largerPacket.c_str(), largerPacket.size());
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(reinterpret_cast<char*>(query.data()), len);
}
BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
- string largerPacket;
bool ednsAdded = false;
bool ecsAdded = false;
ComboAddress remote("192.168.1.25");
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK_EQUAL((size_t) len, query.size());
- BOOST_CHECK_EQUAL(largerPacket.size(), 0);
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
validateQuery(packet, len);
}
BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
- string largerPacket;
bool ednsAdded = false;
bool ecsAdded = false;
ComboAddress remote("192.168.1.25");
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK((size_t) len < query.size());
- BOOST_CHECK_EQUAL(largerPacket.size(), 0);
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
validateQuery(packet, len);
}
BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
- string largerPacket;
bool ednsAdded = false;
bool ecsAdded = false;
ComboAddress remote("192.168.1.25");
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK((size_t) len > query.size());
- BOOST_CHECK_EQUAL(largerPacket.size(), 0);
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
validateQuery(packet, len);
/* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
consumed = 0;
len = query.size();
- qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6));
BOOST_CHECK_EQUAL((size_t) len, query.size());
- BOOST_CHECK(largerPacket.size() > query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
- validateQuery(largerPacket.c_str(), largerPacket.size());
+ validateQuery(reinterpret_cast<char*>(query.data()), len);
}
BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) {