From: Remi Gacogne Date: Wed, 10 Apr 2019 10:22:32 +0000 (+0200) Subject: dnsdist: Prevent copies of DNSQuestion and DNSResponse objects X-Git-Tag: dnsdist-1.4.0-alpha1~2^2~3 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=dd1a3034369980a4e444fabda134b2a927b95634;p=pdns dnsdist: Prevent copies of DNSQuestion and DNSResponse objects --- diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index 4f255b6e8..c9cadfc5e 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -745,7 +745,7 @@ private: class DnstapLogAction : public DNSAction, public boost::noncopyable { public: - DnstapLogAction(const std::string& identity, std::shared_ptr& logger, boost::optional > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) + DnstapLogAction(const std::string& identity, std::shared_ptr& logger, boost::optional > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) { } DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override @@ -755,7 +755,7 @@ public: { if (d_alterFunc) { std::lock_guard lock(g_luamutex); - (*d_alterFunc)(*dq, &message); + (*d_alterFunc)(dq, &message); } } std::string data; @@ -771,13 +771,13 @@ public: private: std::string d_identity; std::shared_ptr d_logger; - boost::optional > d_alterFunc; + boost::optional > d_alterFunc; }; class RemoteLogAction : public DNSAction, public boost::noncopyable { public: - RemoteLogAction(std::shared_ptr& logger, boost::optional > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey) + RemoteLogAction(std::shared_ptr& logger, boost::optional > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey) { } DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override @@ -801,7 +801,7 @@ public: if (d_alterFunc) { std::lock_guard lock(g_luamutex); - (*d_alterFunc)(*dq, &message); + (*d_alterFunc)(dq, &message); } std::string data; @@ -816,7 +816,7 @@ public: } private: std::shared_ptr d_logger; - boost::optional > d_alterFunc; + boost::optional > d_alterFunc; std::string d_serverID; std::string d_ipEncryptKey; }; @@ -871,7 +871,7 @@ private: class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable { public: - DnstapLogResponseAction(const std::string& identity, std::shared_ptr& logger, boost::optional > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) + DnstapLogResponseAction(const std::string& identity, std::shared_ptr& logger, boost::optional > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) { } DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override @@ -883,7 +883,7 @@ public: { if (d_alterFunc) { std::lock_guard lock(g_luamutex); - (*d_alterFunc)(*dr, &message); + (*d_alterFunc)(dr, &message); } } std::string data; @@ -899,13 +899,13 @@ public: private: std::string d_identity; std::shared_ptr d_logger; - boost::optional > d_alterFunc; + boost::optional > d_alterFunc; }; class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable { public: - RemoteLogResponseAction(std::shared_ptr& logger, boost::optional > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME) + RemoteLogResponseAction(std::shared_ptr& logger, boost::optional > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME) { } DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override @@ -929,7 +929,7 @@ public: if (d_alterFunc) { std::lock_guard lock(g_luamutex); - (*d_alterFunc)(*dr, &message); + (*d_alterFunc)(dr, &message); } std::string data; @@ -944,7 +944,7 @@ public: } private: std::shared_ptr d_logger; - boost::optional > d_alterFunc; + boost::optional > d_alterFunc; std::string d_serverID; std::string d_ipEncryptKey; bool d_includeCNAME; @@ -1226,7 +1226,7 @@ void setupLuaActions() return std::shared_ptr(new LuaResponseAction(func)); }); - g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr logger, boost::optional > alterFunc, boost::optional> vars) { + g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr logger, boost::optional > alterFunc, boost::optional> vars) { // avoids potentially-evaluated-expression warning with clang. RemoteLoggerInterface& rl = *logger.get(); if (typeid(rl) != typeid(RemoteLogger)) { @@ -1252,7 +1252,7 @@ void setupLuaActions() #endif }); - g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr logger, boost::optional > alterFunc, boost::optional includeCNAME, boost::optional> vars) { + g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr logger, boost::optional > alterFunc, boost::optional includeCNAME, boost::optional> vars) { // avoids potentially-evaluated-expression warning with clang. RemoteLoggerInterface& rl = *logger.get(); if (typeid(rl) != typeid(RemoteLogger)) { @@ -1278,7 +1278,7 @@ void setupLuaActions() #endif }); - g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr logger, boost::optional > alterFunc) { + g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr logger, boost::optional > alterFunc) { #ifdef HAVE_PROTOBUF return std::shared_ptr(new DnstapLogAction(identity, logger, alterFunc)); #else @@ -1286,7 +1286,7 @@ void setupLuaActions() #endif }); - g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr logger, boost::optional > alterFunc) { + g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr logger, boost::optional > alterFunc) { #ifdef HAVE_PROTOBUF return std::shared_ptr(new DnstapLogResponseAction(identity, logger, alterFunc)); #else diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 461c5f389..0e3d00c48 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1006,7 +1006,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po bool countQuery{true}; if(g_qcount.filter) { std::lock_guard lock(g_luamutex); - std::tie (countQuery, qname) = g_qcount.filter(dq); + std::tie (countQuery, qname) = g_qcount.filter(&dq); } if(countQuery) { diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 663e18061..d0a3a1784 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -65,6 +65,9 @@ struct DNSQuestion const uint16_t* flags = getFlagsFromDNSHeader(dh); origFlags = *flags; } + DNSQuestion(const DNSQuestion&) = delete; + DNSQuestion& operator=(const DNSQuestion&) = delete; + DNSQuestion(DNSQuestion&&) = default; #ifdef HAVE_PROTOBUF boost::optional uniqueId; @@ -108,6 +111,9 @@ struct DNSResponse : DNSQuestion { DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_): DNSQuestion(name, type, class_, consumed, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { } + DNSResponse(const DNSResponse&) = delete; + DNSResponse& operator=(const DNSResponse&) = delete; + DNSResponse(DNSResponse&&) = default; }; /* so what could you do: @@ -565,7 +571,7 @@ struct IDState }; typedef std::unordered_map QueryCountRecords; -typedef std::function(DNSQuestion dq)> QueryCountFilter; +typedef std::function(const DNSQuestion* dq)> QueryCountFilter; struct QueryCount { QueryCount() { diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index f34968e74..100dd4710 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -1105,8 +1105,7 @@ static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, co { dnsheader* dh = reinterpret_cast(query.data()); - DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime); - return dq; + return DNSQuestion(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime); } static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, vector& query, bool resizeBuffer=true)