From: Remi Gacogne Date: Wed, 28 Mar 2018 13:03:16 +0000 (+0200) Subject: dnsdist: Add `checkFunction` to implement a dynamic health check X-Git-Tag: dnsdist-1.3.0~4^2 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=98650fdefa53920e2f30ead3ed8ff71080c62d60;p=pdns dnsdist: Add `checkFunction` to implement a dynamic health check --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 1c9f18c60..ca5c1b8ce 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -103,7 +103,7 @@ static void parseLocalBindVars(boost::optional vars, bool& doTCP, b void setupLuaConfig(bool client) { - typedef std::unordered_map > > > newserver_t; + typedef std::unordered_map >, DownstreamState::checkfunc_t > > newserver_t; g_lua.writeFunction("inClientStartup", [client]() { return client && !g_configurationDone; @@ -302,6 +302,10 @@ void setupLuaConfig(bool client) ret->checkClass=std::stoi(boost::get(vars["checkClass"])); } + if(vars.count("checkFunction")) { + ret->checkFunction= boost::get(vars["checkFunction"]); + } + if(vars.count("setCD")) { ret->setCD=boost::get(vars["setCD"]); } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index fe665e513..fd05ef4b6 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1657,14 +1657,38 @@ catch(...) static bool upCheck(DownstreamState& ds) try { - vector packet; - DNSPacketWriter dpw(packet, ds.checkName, ds.checkType.getCode(), ds.checkClass); - dnsheader * requestHeader = dpw.getHeader(); - requestHeader->rd=true; + DNSName checkName = ds.checkName; + uint16_t checkType = ds.checkType.getCode(); + uint16_t checkClass = ds.checkClass; + dnsheader checkHeader; + memset(&checkHeader, 0, sizeof(checkHeader)); + + checkHeader.qdcount = htons(1); +#ifdef HAVE_LIBSODIUM + checkHeader.id = randombytes_random() % 65536; +#else + checkHeader.id = random() % 65536; +#endif + + checkHeader.rd = true; if (ds.setCD) { - requestHeader->cd = true; + checkHeader.cd = true; + } + + + if (ds.checkFunction) { + std::lock_guard lock(g_luamutex); + auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader); + checkName = std::get<0>(ret); + checkType = std::get<1>(ret); + checkClass = std::get<2>(ret); } + vector packet; + DNSPacketWriter dpw(packet, checkName, checkType, checkClass); + dnsheader * requestHeader = dpw.getHeader(); + *requestHeader = checkHeader; + Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM); sock.setNonBlocking(); if (!IsAnyAddress(ds.sourceAddr)) { @@ -1697,7 +1721,7 @@ try string reply; sock.recvFrom(reply, ds.remote); - const dnsheader * responseHeader = (const dnsheader *) reply.c_str(); + const dnsheader * responseHeader = reinterpret_cast(reply.c_str()); if (reply.size() < sizeof(*responseHeader)) { if (g_verboseHealthChecks) @@ -1729,7 +1753,16 @@ try return false; } - // XXX fixme do bunch of checking here etc + uint16_t receivedType; + uint16_t receivedClass; + DNSName receivedName(reply.c_str(), reply.size(), sizeof(dnsheader), false, &receivedType, &receivedClass); + + if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) { + if (g_verboseHealthChecks) + infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass); + return false; + } + return true; } catch(const std::exception& e) diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 862efa60b..fc2f0bb62 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -524,6 +524,8 @@ extern std::shared_ptr g_tcpclientthreads; struct DownstreamState { + typedef std::function(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t; + DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf, size_t numberOfSockets); DownstreamState(const ComboAddress& remote_): DownstreamState(remote_, ComboAddress(), 0, 1) {} ~DownstreamState() @@ -544,6 +546,7 @@ struct DownstreamState QPSLimiter qps; vector idStates; ComboAddress sourceAddr; + checkfunc_t checkFunction; DNSName checkName{"a.root-servers.net."}; QType checkType{QType::A}; uint16_t checkClass{QClass::IN}; diff --git a/pdns/dnsdistdist/docs/guides/downstreams.rst b/pdns/dnsdistdist/docs/guides/downstreams.rst index 3e73b162c..47666517a 100644 --- a/pdns/dnsdistdist/docs/guides/downstreams.rst +++ b/pdns/dnsdistdist/docs/guides/downstreams.rst @@ -21,6 +21,8 @@ These two equivalent configurations give you sane load balancing using a very se Many users will simply be done with this configuration. It works as well for authoritative as for recursive servers. +.. _Healthcheck: + Healthcheck ----------- dnsdist uses a health check, sent once every second, to determine the availability of a backend server. @@ -37,6 +39,21 @@ e.g.:: newServer({address="192.0.2.1", checkType="AAAA", checkType=DNSClass.CHAOS, checkName="a.root-servers.net.", mustResolve=true}) +Since the 1.3.0 release, the ``checkFunction`` option is also supported, taking a ``Lua`` function as parameter. This function receives a DNSName, two integers and a ``DNSHeader`` object (:ref:`DNSHeader`) +representing the QName, QType and QClass of the health check query as well as the DNS header, as they are defined before the function was called. The function must return a DNSName and two integers +representing the new QName, QType and QClass, and can directly modify the ``DNSHeader`` object. + +The following example sets the CD flag to true and change the QName to "powerdns.com." and the QType to AAAA while keeping the initial QClass. +.. code-block:: lua + + function myHealthCheck(qname, qtype, qclass, dh) + dh:setCD(true) + + return newDNSName("powerdns.com."), dnsdist.AAAA, qclass + end + + newServer("2620:0:0ccd::2") + Source address selection ------------------------ diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index b136c5cea..4094ed1b4 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -21,7 +21,7 @@ Within dnsdist several core object types exist: * :class:`NetmaskGroup`: represents a group of netmasks * :class:`QPSLimiter`: implements a QPS-based filter * :class:`SuffixMatchNode`: represents a group of domain suffixes for rapid testing of membership -* :class:`DNSHeader`: represents the header of a DNS packet +* :class:`DNSHeader`: represents the header of a DNS packet, see :ref:`DNSHeader` * :class:`ClientState`: sometimes also called Bind or Frontend, represents the addresses and ports dnsdist is listening on The existence of most of these objects can mostly be ignored, unless you plan to write your own hooks and policies, but it helps to understand an expressions like: @@ -235,6 +235,9 @@ Servers .. function:: newServer(server_string) newServer(server_table) + .. versionchanged:: 1.3.0 + ``checkFunction`` option added. + Add a new backend server. Call this function with either a string:: newServer( @@ -259,6 +262,7 @@ Servers checkClass=NUM, -- Use NUM as QCLASS in the health-check query, default: DNSClass.IN checkName=STRING, -- Use STRING as QNAME in the health-check query, default: "a.root-servers.net." checkType=STRING, -- Use STRING as QTYPE in the health-check query, default: "A" + checkFunction=FUNCTION -- Use this function to dynamically set the QNAME, QTYPE and QCLASS to use in the health-check query (see :ref:`Healthcheck`) setCD=BOOL, -- Set the CD (Checking Disabled) flag in the health-check query, default: false maxCheckFailures=NUM, -- Allow NUM check failures before declaring the backend down, default: 1 mustResolve=BOOL, -- Set to true when the health check MUST return a NOERROR RCODE and an answer diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 4d5a83ca7..bcdb65376 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -50,6 +50,9 @@ class DNSDistTest(unittest.TestCase): _acl = ['127.0.0.1/32'] _consolePort = 5199 _consoleKey = None + _healthCheckName = 'a.root-servers.net.' + _healthCheckCounter = 0 + _healthCheckAnswerUnexpected = False @classmethod def startResponders(cls): @@ -144,8 +147,10 @@ class DNSDistTest(unittest.TestCase): if len(request.question) != 1: print("Skipping query with question count %d" % (len(request.question))) return None - healthcheck = not str(request.question[0].name).endswith('tests.powerdns.com.') - if not healthcheck: + healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) + if healthCheck: + cls._healthCheckCounter += 1 + else: cls._ResponderIncrementCounter() if not fromQueue.empty(): response = fromQueue.get(True, cls._queueTimeout) @@ -154,7 +159,7 @@ class DNSDistTest(unittest.TestCase): response.id = request.id toQueue.put(request, True, cls._queueTimeout) - if not response: + if not response and (healthCheck or cls._healthCheckAnswerUnexpected): # unexpected query, or health check response = dns.message.make_response(request) @@ -193,6 +198,10 @@ class DNSDistTest(unittest.TestCase): (conn, _) = sock.accept() conn.settimeout(2.0) data = conn.recv(2) + if not data: + conn.close() + continue + (datalen,) = struct.unpack("!H", data) data = conn.recv(datalen) request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) @@ -381,6 +390,8 @@ class DNSDistTest(unittest.TestCase): for key in self._responsesCounter: self._responsesCounter[key] = 0 + self._healthCheckCounter = 0 + # Make sure the queues are empty, in case # a previous test failed while not self._toResponderQueue.empty(): diff --git a/regression-tests.dnsdist/test_HealthChecks.py b/regression-tests.dnsdist/test_HealthChecks.py new file mode 100644 index 000000000..7acbbf552 --- /dev/null +++ b/regression-tests.dnsdist/test_HealthChecks.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +import base64 +import time +import dns +from dnsdisttests import DNSDistTest + +class HealthCheckTest(DNSDistTest): + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + newServer{address="127.0.0.1:%d"} + """ + + def getBackendStatus(self): + return self.sendConsoleCommand("if getServer(0):isUp() then return 'up' else return 'down' end").strip("\n") + +class TestDefaultHealthCheck(HealthCheckTest): + # this test suite uses a different responder port + # because we need fresh counters + _testServerPort = 5380 + + def testDefault(self): + """ + HealthChecks: Default + """ + before = TestDefaultHealthCheck._healthCheckCounter + time.sleep(1) + self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') + + self.sendConsoleCommand("getServer(0):setUp()") + self.assertEquals(self.getBackendStatus(), 'up') + + before = TestDefaultHealthCheck._healthCheckCounter + time.sleep(1) + self.assertEquals(TestDefaultHealthCheck._healthCheckCounter, before) + + self.sendConsoleCommand("getServer(0):setDown()") + self.assertEquals(self.getBackendStatus(), 'down') + + before = TestDefaultHealthCheck._healthCheckCounter + time.sleep(1) + self.assertEquals(TestDefaultHealthCheck._healthCheckCounter, before) + + self.sendConsoleCommand("getServer(0):setAuto()") + # we get back the previous state, which was up + self.assertEquals(self.getBackendStatus(), 'up') + + before = TestDefaultHealthCheck._healthCheckCounter + time.sleep(1) + self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') + + self.sendConsoleCommand("getServer(0):setDown()") + self.assertEquals(self.getBackendStatus(), 'down') + self.sendConsoleCommand("getServer(0):setAuto(false)") + # we specified that the new state should be up until the next health check + self.assertEquals(self.getBackendStatus(), 'down') + + before = TestDefaultHealthCheck._healthCheckCounter + time.sleep(1) + self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') + +class TestHealthCheckForcedUP(HealthCheckTest): + # this test suite uses a different responder port + # because we need fresh counters + _testServerPort = 5381 + + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + srv = newServer{address="127.0.0.1:%d"} + srv:setUp() + """ + + def testForcedUp(self): + """ + HealthChecks: Forced UP + """ + before = TestHealthCheckForcedUP._healthCheckCounter + time.sleep(1) + self.assertEquals(TestHealthCheckForcedUP._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') + +class TestHealthCheckForcedDown(HealthCheckTest): + # this test suite uses a different responder port + # because we need fresh counters + _testServerPort = 5382 + + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + srv = newServer{address="127.0.0.1:%d"} + srv:setDown() + """ + + def testForcedDown(self): + """ + HealthChecks: Forced Down + """ + before = TestHealthCheckForcedDown._healthCheckCounter + time.sleep(1) + self.assertEquals(TestHealthCheckForcedDown._healthCheckCounter, before) + +class TestHealthCheckCustomName(HealthCheckTest): + # this test suite uses a different responder port + # because it uses a different health check name + _testServerPort = 5383 + + _healthCheckName = 'powerdns.com.' + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_healthCheckName'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + srv = newServer{address="127.0.0.1:%d", checkName='%s'} + """ + + def testAuto(self): + """ + HealthChecks: Custom name + """ + before = TestHealthCheckCustomName._healthCheckCounter + time.sleep(1) + self.assertGreater(TestHealthCheckCustomName._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') + +class TestHealthCheckCustomNameNoAnswer(HealthCheckTest): + # this test suite uses a different responder port + # because it uses a different health check configuration + _testServerPort = 5384 + + _healthCheckAnswerUnexpected = False + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + srv = newServer{address="127.0.0.1:%d", checkName='powerdns.com.'} + """ + + def testAuto(self): + """ + HealthChecks: Custom name not expected by the responder + """ + before = TestHealthCheckCustomNameNoAnswer._healthCheckCounter + time.sleep(1) + self.assertEquals(TestHealthCheckCustomNameNoAnswer._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'down') + +class TestHealthCheckCustomFunction(HealthCheckTest): + # this test suite uses a different responder port + # because it uses a different health check configuration + _testServerPort = 5385 + _healthCheckAnswerUnexpected = False + + _healthCheckName = 'powerdns.com.' + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") + + function myHealthCheckFunction(qname, qtype, qclass, dh) + dh:setCD(true) + + return newDNSName('powerdns.com.'), dnsdist.AAAA, qclass + end + + srv = newServer{address="127.0.0.1:%d", checkName='powerdns.org.', checkFunction=myHealthCheckFunction} + """ + + def testAuto(self): + """ + HealthChecks: Custom function + """ + before = TestHealthCheckCustomFunction._healthCheckCounter + time.sleep(1) + self.assertGreater(TestHealthCheckCustomFunction._healthCheckCounter, before) + self.assertEquals(self.getBackendStatus(), 'up') diff --git a/regression-tests.dnsdist/test_TCPLimits.py b/regression-tests.dnsdist/test_TCPLimits.py index 395ade98c..41ccbb9b4 100644 --- a/regression-tests.dnsdist/test_TCPLimits.py +++ b/regression-tests.dnsdist/test_TCPLimits.py @@ -6,6 +6,11 @@ from dnsdisttests import DNSDistTest, range class TestTCPLimits(DNSDistTest): + # this test suite uses a different responder port + # because it uses a different health check configuration + _testServerPort = 5395 + _healthCheckAnswerUnexpected = True + _tcpIdleTimeout = 2 _maxTCPQueriesPerConn = 5 _maxTCPConnsPerClient = 3