]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add `checkFunction` to implement a dynamic health check
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 28 Mar 2018 13:03:16 +0000 (15:03 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 28 Mar 2018 13:03:16 +0000 (15:03 +0200)
pdns/dnsdist-lua.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/docs/guides/downstreams.rst
pdns/dnsdistdist/docs/reference/config.rst
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_HealthChecks.py [new file with mode: 0644]
regression-tests.dnsdist/test_TCPLimits.py

index 1c9f18c604c19079c44dd655c7f5e076e35388eb..ca5c1b8ce4773d434d2f879d838a7269874ed06a 100644 (file)
@@ -103,7 +103,7 @@ static void parseLocalBindVars(boost::optional<localbind_t> vars, bool& doTCP, b
 
 void setupLuaConfig(bool client)
 {
-  typedef std::unordered_map<std::string, boost::variant<bool, std::string, vector<pair<int, std::string> > > > newserver_t;
+  typedef std::unordered_map<std::string, boost::variant<bool, std::string, vector<pair<int, std::string> >, DownstreamState::checkfunc_t > > newserver_t;
 
   g_lua.writeFunction("inClientStartup", [client]() {
         return client && !g_configurationDone;
@@ -302,6 +302,10 @@ void setupLuaConfig(bool client)
                          ret->checkClass=std::stoi(boost::get<string>(vars["checkClass"]));
                        }
 
+                        if(vars.count("checkFunction")) {
+                         ret->checkFunction= boost::get<DownstreamState::checkfunc_t>(vars["checkFunction"]);
+                       }
+
                        if(vars.count("setCD")) {
                          ret->setCD=boost::get<bool>(vars["setCD"]);
                        }
index fe665e5134ae2b69751b863d30835ae9bc5bd76f..fd05ef4b69b88efd1c43ef608242d585ea097002 100644 (file)
@@ -1657,14 +1657,38 @@ catch(...)
 static bool upCheck(DownstreamState& ds)
 try
 {
-  vector<uint8_t> packet;
-  DNSPacketWriter dpw(packet, ds.checkName, ds.checkType.getCode(), ds.checkClass);
-  dnsheader * requestHeader = dpw.getHeader();
-  requestHeader->rd=true;
+  DNSName checkName = ds.checkName;
+  uint16_t checkType = ds.checkType.getCode();
+  uint16_t checkClass = ds.checkClass;
+  dnsheader checkHeader;
+  memset(&checkHeader, 0, sizeof(checkHeader));
+
+  checkHeader.qdcount = htons(1);
+#ifdef HAVE_LIBSODIUM
+  checkHeader.id = randombytes_random() % 65536;
+#else
+  checkHeader.id = random() % 65536;
+#endif
+
+  checkHeader.rd = true;
   if (ds.setCD) {
-    requestHeader->cd = true;
+    checkHeader.cd = true;
+  }
+
+
+  if (ds.checkFunction) {
+    std::lock_guard<std::mutex> lock(g_luamutex);
+    auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader);
+    checkName = std::get<0>(ret);
+    checkType = std::get<1>(ret);
+    checkClass = std::get<2>(ret);
   }
 
+  vector<uint8_t> packet;
+  DNSPacketWriter dpw(packet, checkName, checkType, checkClass);
+  dnsheader * requestHeader = dpw.getHeader();
+  *requestHeader = checkHeader;
+
   Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM);
   sock.setNonBlocking();
   if (!IsAnyAddress(ds.sourceAddr)) {
@@ -1697,7 +1721,7 @@ try
   string reply;
   sock.recvFrom(reply, ds.remote);
 
-  const dnsheader * responseHeader = (const dnsheader *) reply.c_str();
+  const dnsheader * responseHeader = reinterpret_cast<const dnsheader *>(reply.c_str());
 
   if (reply.size() < sizeof(*responseHeader)) {
     if (g_verboseHealthChecks)
@@ -1729,7 +1753,16 @@ try
     return false;
   }
 
-  // XXX fixme do bunch of checking here etc 
+  uint16_t receivedType;
+  uint16_t receivedClass;
+  DNSName receivedName(reply.c_str(), reply.size(), sizeof(dnsheader), false, &receivedType, &receivedClass);
+
+  if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
+    if (g_verboseHealthChecks)
+      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
+    return false;
+  }
+
   return true;
 }
 catch(const std::exception& e)
index 862efa60b5ac51a3890bbd5f7f1f16c3bb87d8bf..fc2f0bb6281f7ff2eab5de30a80f413382c71e8c 100644 (file)
@@ -524,6 +524,8 @@ extern std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
 
 struct DownstreamState
 {
+  typedef std::function<std::tuple<DNSName, uint16_t, uint16_t>(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t;
+
   DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf, size_t numberOfSockets);
   DownstreamState(const ComboAddress& remote_): DownstreamState(remote_, ComboAddress(), 0, 1) {}
   ~DownstreamState()
@@ -544,6 +546,7 @@ struct DownstreamState
   QPSLimiter qps;
   vector<IDState> idStates;
   ComboAddress sourceAddr;
+  checkfunc_t checkFunction;
   DNSName checkName{"a.root-servers.net."};
   QType checkType{QType::A};
   uint16_t checkClass{QClass::IN};
index 3e73b162c91951c146ebb70e2adb959c8dc4ad3a..47666517a118d2eb20de72bd752ca5dd77505901 100644 (file)
@@ -21,6 +21,8 @@ These two equivalent configurations give you sane load balancing using a very se
 Many users will simply be done with this configuration.
 It works as well for authoritative as for recursive servers.
 
+.. _Healthcheck:
+
 Healthcheck
 -----------
 dnsdist uses a health check, sent once every second, to determine the availability of a backend server.
@@ -37,6 +39,21 @@ e.g.::
 
   newServer({address="192.0.2.1", checkType="AAAA", checkClass=DNSClass.CHAOS, checkName="a.root-servers.net.", mustResolve=true})
 
+Since the 1.3.0 release, the ``checkFunction`` option is also supported, taking a ``Lua`` function as parameter. This function receives a DNSName, two integers and a ``DNSHeader`` object (:ref:`DNSHeader`)
+representing the QName, QType and QClass of the health check query as well as the DNS header, as they are defined before the function was called. The function must return a DNSName and two integers
+representing the new QName, QType and QClass, and can directly modify the ``DNSHeader`` object.
+
+The following example sets the CD flag to true and changes the QName to "powerdns.com." and the QType to AAAA while keeping the initial QClass.
+
+.. code-block:: lua
+
+    function myHealthCheck(qname, qtype, qclass, dh)
+      dh:setCD(true)
+
+      return newDNSName("powerdns.com."), dnsdist.AAAA, qclass
+    end
+
+    newServer({address="2620:0:0ccd::2", checkFunction=myHealthCheck})
+
 Source address selection
 ------------------------
 
index b136c5ceaa9fb66c106f20d7c0ca518388fc8a03..4094ed1b453684e12981d8867e615a2b744e1151 100644 (file)
@@ -21,7 +21,7 @@ Within dnsdist several core object types exist:
 * :class:`NetmaskGroup`: represents a group of netmasks
 * :class:`QPSLimiter`: implements a QPS-based filter
 * :class:`SuffixMatchNode`: represents a group of domain suffixes for rapid testing of membership
-* :class:`DNSHeader`: represents the header of a DNS packet
+* :class:`DNSHeader`: represents the header of a DNS packet, see :ref:`DNSHeader`
 * :class:`ClientState`: sometimes also called Bind or Frontend, represents the addresses and ports dnsdist is listening on
 
 The existence of most of these objects can mostly be ignored, unless you plan to write your own hooks and policies, but it helps to understand an expression like:
@@ -235,6 +235,9 @@ Servers
 .. function:: newServer(server_string)
               newServer(server_table)
 
+  .. versionchanged:: 1.3.0
+    ``checkFunction`` option added.
+
   Add a new backend server. Call this function with either a string::
 
     newServer(
@@ -259,6 +262,7 @@ Servers
       checkClass=NUM,        -- Use NUM as QCLASS in the health-check query, default: DNSClass.IN
       checkName=STRING,      -- Use STRING as QNAME in the health-check query, default: "a.root-servers.net."
       checkType=STRING,      -- Use STRING as QTYPE in the health-check query, default: "A"
+      checkFunction=FUNCTION, -- Use this function to dynamically set the QNAME, QTYPE and QCLASS to use in the health-check query (see :ref:`Healthcheck`)
       setCD=BOOL,            -- Set the CD (Checking Disabled) flag in the health-check query, default: false
       maxCheckFailures=NUM,  -- Allow NUM check failures before declaring the backend down, default: 1
       mustResolve=BOOL,      -- Set to true when the health check MUST return a NOERROR RCODE and an answer
index 4d5a83ca7569f70cb833537966e2499562bdee3f..bcdb65376a9f8ad7780e26fbc2b3f43a8d2162c8 100644 (file)
@@ -50,6 +50,9 @@ class DNSDistTest(unittest.TestCase):
     _acl = ['127.0.0.1/32']
     _consolePort = 5199
     _consoleKey = None
+    _healthCheckName = 'a.root-servers.net.'
+    _healthCheckCounter = 0
+    _healthCheckAnswerUnexpected = False
 
     @classmethod
     def startResponders(cls):
@@ -144,8 +147,10 @@ class DNSDistTest(unittest.TestCase):
         if len(request.question) != 1:
             print("Skipping query with question count %d" % (len(request.question)))
             return None
-        healthcheck = not str(request.question[0].name).endswith('tests.powerdns.com.')
-        if not healthcheck:
+        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
+        if healthCheck:
+            cls._healthCheckCounter += 1
+        else:
             cls._ResponderIncrementCounter()
             if not fromQueue.empty():
                 response = fromQueue.get(True, cls._queueTimeout)
@@ -154,7 +159,7 @@ class DNSDistTest(unittest.TestCase):
                     response.id = request.id
                     toQueue.put(request, True, cls._queueTimeout)
 
-        if not response:
+        if not response and (healthCheck or cls._healthCheckAnswerUnexpected):
             # unexpected query, or health check
             response = dns.message.make_response(request)
 
@@ -193,6 +198,10 @@ class DNSDistTest(unittest.TestCase):
             (conn, _) = sock.accept()
             conn.settimeout(2.0)
             data = conn.recv(2)
+            if not data:
+                conn.close()
+                continue
+
             (datalen,) = struct.unpack("!H", data)
             data = conn.recv(datalen)
             request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
@@ -381,6 +390,8 @@ class DNSDistTest(unittest.TestCase):
         for key in self._responsesCounter:
             self._responsesCounter[key] = 0
 
+        self._healthCheckCounter = 0
+
         # Make sure the queues are empty, in case
         # a previous test failed
         while not self._toResponderQueue.empty():
diff --git a/regression-tests.dnsdist/test_HealthChecks.py b/regression-tests.dnsdist/test_HealthChecks.py
new file mode 100644 (file)
index 0000000..7acbbf5
--- /dev/null
@@ -0,0 +1,179 @@
+#!/usr/bin/env python
+import base64
+import time
+import dns
+from dnsdisttests import DNSDistTest
+
+class HealthCheckTest(DNSDistTest):
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    newServer{address="127.0.0.1:%d"}
+    """
+
+    def getBackendStatus(self):
+        return self.sendConsoleCommand("if getServer(0):isUp() then return 'up' else return 'down' end").strip("\n")
+
+class TestDefaultHealthCheck(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because we need fresh counters
+    _testServerPort = 5380
+
+    def testDefault(self):
+        """
+        HealthChecks: Default
+        """
+        before = TestDefaultHealthCheck._healthCheckCounter
+        time.sleep(1)
+        self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+        self.sendConsoleCommand("getServer(0):setUp()")
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+        before = TestDefaultHealthCheck._healthCheckCounter
+        time.sleep(1)
+        self.assertEquals(TestDefaultHealthCheck._healthCheckCounter, before)
+
+        self.sendConsoleCommand("getServer(0):setDown()")
+        self.assertEquals(self.getBackendStatus(), 'down')
+
+        before = TestDefaultHealthCheck._healthCheckCounter
+        time.sleep(1)
+        self.assertEquals(TestDefaultHealthCheck._healthCheckCounter, before)
+
+        self.sendConsoleCommand("getServer(0):setAuto()")
+        # we get back the previous state, which was up
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+        before = TestDefaultHealthCheck._healthCheckCounter
+        time.sleep(1)
+        self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+        self.sendConsoleCommand("getServer(0):setDown()")
+        self.assertEquals(self.getBackendStatus(), 'down')
+        self.sendConsoleCommand("getServer(0):setAuto(false)")
+        # we specified that the new state should be down until the next health check
+        self.assertEquals(self.getBackendStatus(), 'down')
+
+        before = TestDefaultHealthCheck._healthCheckCounter
+        time.sleep(1)
+        self.assertGreater(TestDefaultHealthCheck._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+class TestHealthCheckForcedUP(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because we need fresh counters
+    _testServerPort = 5381
+
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    srv = newServer{address="127.0.0.1:%d"}
+    srv:setUp()
+    """
+
+    def testForcedUp(self):
+        """
+        HealthChecks: Forced UP
+        """
+        before = TestHealthCheckForcedUP._healthCheckCounter
+        time.sleep(1)
+        self.assertEquals(TestHealthCheckForcedUP._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+class TestHealthCheckForcedDown(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because we need fresh counters
+    _testServerPort = 5382
+
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    srv = newServer{address="127.0.0.1:%d"}
+    srv:setDown()
+    """
+
+    def testForcedDown(self):
+        """
+        HealthChecks: Forced Down
+        """
+        before = TestHealthCheckForcedDown._healthCheckCounter
+        time.sleep(1)
+        self.assertEquals(TestHealthCheckForcedDown._healthCheckCounter, before)
+
+class TestHealthCheckCustomName(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because it uses a different health check name
+    _testServerPort = 5383
+
+    _healthCheckName = 'powerdns.com.'
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_healthCheckName']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    srv = newServer{address="127.0.0.1:%d", checkName='%s'}
+    """
+
+    def testAuto(self):
+        """
+        HealthChecks: Custom name
+        """
+        before = TestHealthCheckCustomName._healthCheckCounter
+        time.sleep(1)
+        self.assertGreater(TestHealthCheckCustomName._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
+
+class TestHealthCheckCustomNameNoAnswer(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because it uses a different health check configuration
+    _testServerPort = 5384
+
+    _healthCheckAnswerUnexpected = False
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    srv = newServer{address="127.0.0.1:%d", checkName='powerdns.com.'}
+    """
+
+    def testAuto(self):
+        """
+        HealthChecks: Custom name not expected by the responder
+        """
+        before = TestHealthCheckCustomNameNoAnswer._healthCheckCounter
+        time.sleep(1)
+        self.assertEquals(TestHealthCheckCustomNameNoAnswer._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'down')
+
+class TestHealthCheckCustomFunction(HealthCheckTest):
+    # this test suite uses a different responder port
+    # because it uses a different health check configuration
+    _testServerPort = 5385
+    _healthCheckAnswerUnexpected = False
+
+    _healthCheckName = 'powerdns.com.'
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+
+    function myHealthCheckFunction(qname, qtype, qclass, dh)
+      dh:setCD(true)
+
+      return newDNSName('powerdns.com.'), dnsdist.AAAA, qclass
+    end
+
+    srv = newServer{address="127.0.0.1:%d", checkName='powerdns.org.', checkFunction=myHealthCheckFunction}
+    """
+
+    def testAuto(self):
+        """
+        HealthChecks: Custom function
+        """
+        before = TestHealthCheckCustomFunction._healthCheckCounter
+        time.sleep(1)
+        self.assertGreater(TestHealthCheckCustomFunction._healthCheckCounter, before)
+        self.assertEquals(self.getBackendStatus(), 'up')
index 395ade98ca98b02dc736da6f8e691a396cb7b8a8..41ccbb9b451770ac8a74e1f811bae993c1ed5867 100644 (file)
@@ -6,6 +6,11 @@ from dnsdisttests import DNSDistTest, range
 
 class TestTCPLimits(DNSDistTest):
 
+    # this test suite uses a different responder port
+    # because it uses a different health check configuration
+    _testServerPort = 5395
+    _healthCheckAnswerUnexpected = True
+
     _tcpIdleTimeout = 2
     _maxTCPQueriesPerConn = 5
     _maxTCPConnsPerClient = 3