]> granicus.if.org Git - pdns/commitdiff
dnsdist: add ability to update webserver credentials
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Tue, 30 Oct 2018 13:29:51 +0000 (14:29 +0100)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Tue, 30 Oct 2018 13:29:51 +0000 (14:29 +0100)
pdns/dnsdist-console.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-web.cc
pdns/dnsdist.hh
pdns/dnsdistdist/docs/guides/webserver.rst
pdns/dnsdistdist/docs/reference/config.rst
regression-tests.dnsdist/test_API.py

index d0195c9ecc766f2bae176d3e636403d5afda70d8..874ea4aac0b989b1489e7f458391d178cefefdf7 100644 (file)
@@ -446,6 +446,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. Default to 1 which means that the feature is disabled and recvmsg() is used instead" },
   { "setUDPTimeout", true, "n", "set the maximum time dnsdist will wait for a response from a backend over UDP, in seconds" },
   { "setVerboseHealthChecks", true, "bool", "set whether health check errors will be logged" },
+  { "setWebserverConfig", true, "password [, apiKey [, customHeaders ]]", "Updates webserver configuration" },
   { "show", true, "string", "outputs `string`" },
   { "showACL", true, "", "show our ACL set" },
   { "showBinds", true, "", "show listening addresses (frontends)" },
index 40abb4bce865c5b91f718a57f4042fb3fc55dd82..cae1df01812c1c933c1dcbcb3bdfd2de614a5ca0 100644 (file)
@@ -631,7 +631,8 @@ void setupLuaConfig(bool client)
        SBind(sock, local);
        SListen(sock, 5);
        auto launch=[sock, local, password, apiKey, customHeaders]() {
-         thread t(dnsdistWebserverThread, sock, local, password, apiKey ? *apiKey : "", customHeaders);
+          setWebserverConfig(password, apiKey, customHeaders);
+          thread t(dnsdistWebserverThread, sock, local);
          t.detach();
        };
        if(g_launchWork)
@@ -646,6 +647,11 @@ void setupLuaConfig(bool client)
 
     });
 
+  g_lua.writeFunction("setWebserverConfig", [](const std::string& password, const boost::optional<std::string> apiKey, const boost::optional<std::map<std::string, std::string> > customHeaders) {
+      setLuaSideEffect();
+      setWebserverConfig(password, apiKey, customHeaders);
+    });
+
   g_lua.writeFunction("controlSocket", [client](const std::string& str) {
       setLuaSideEffect();
       ComboAddress local(str, 5199);
index a50d68f5549d5ab6b792323aad22ec3ab37594a7..f038b2cf90613ecc450e193155225461c6f6a1cf 100644 (file)
@@ -38,6 +38,7 @@
 #include  <boost/format.hpp>
 
 bool g_apiReadWrite{false};
+WebserverConfig g_webserverConfig;
 std::string g_apiConfigDirectory;
 
 static bool apiWriteConfigFile(const string& filebasename, const string& content)
@@ -133,23 +134,25 @@ static bool isAStatsRequest(const YaHTTP::Request& req)
   return req.url.path == "/jsonstat" || req.url.path == "/metrics";
 }
 
-static bool compareAuthorization(const YaHTTP::Request& req, const string &expected_password, const string& expectedApiKey)
+static bool compareAuthorization(const YaHTTP::Request& req)
 {
+  std::lock_guard<std::mutex> lock(g_webserverConfig.lock);
+
   if (isAnAPIRequest(req)) {
     /* Access to the API requires a valid API key */
-    if (checkAPIKey(req, expectedApiKey)) {
+    if (checkAPIKey(req, g_webserverConfig.apiKey)) {
       return true;
     }
 
-    return isAnAPIRequestAllowedWithWebAuth(req) && checkWebPassword(req, expected_password);
+    return isAnAPIRequestAllowedWithWebAuth(req) && checkWebPassword(req, g_webserverConfig.password);
   }
 
   if (isAStatsRequest(req)) {
     /* Access to the stats is allowed for both API and Web users */
-    return checkAPIKey(req, expectedApiKey) || checkWebPassword(req, expected_password);
+    return checkAPIKey(req, g_webserverConfig.apiKey) || checkWebPassword(req, g_webserverConfig.password);
   }
 
-  return checkWebPassword(req, expected_password);
+  return checkWebPassword(req, g_webserverConfig.password);
 }
 
 static bool isMethodAllowed(const YaHTTP::Request& req)
@@ -241,7 +244,7 @@ static json11::Json::array someResponseRulesToJson(GlobalStateHolder<vector<T>>*
   return responseRules;
 }
 
-static void connectionThread(int sock, ComboAddress remote, string password, string apiKey, const boost::optional<std::map<std::string, std::string> >& customHeaders)
+static void connectionThread(int sock, ComboAddress remote)
 {
   setThreadName("dnsdist/webConn");
 
@@ -276,8 +279,12 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
     resp.version = req.version;
     const string charset = "; charset=utf-8";
 
-    addCustomHeaders(resp, customHeaders);
-    addSecurityHeaders(resp, customHeaders);
+    {
+      std::lock_guard<std::mutex> lock(g_webserverConfig.lock);
+
+      addCustomHeaders(resp, g_webserverConfig.customHeaders);
+      addSecurityHeaders(resp, g_webserverConfig.customHeaders);
+    }
     /* indicate that the connection will be closed after completion of the response */
     resp.headers["Connection"] = "close";
 
@@ -289,7 +296,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
       handleCORS(req, resp);
       resp.status=200;
     }
-    else if (!compareAuthorization(req, password, apiKey)) {
+    else if (!compareAuthorization(req)) {
       YaHTTP::strstr_map_t::iterator header = req.headers.find("authorization");
       if (header != req.headers.end())
         errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, remote.toStringWithPort());
@@ -402,7 +409,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
           // Prometheus suggest using '_' instead of '-'
           std::string prometheusMetricName = "dnsdist_" + boost::replace_all_copy(metricName, "-", "_");
 
-          MetricDefinition metricDetails; 
+          MetricDefinition metricDetails;
 
           if (!g_metricDefinitions.getMetricDetails(metricName, metricDetails)) {
               vinfolog("Do not have metric details for %s", metricName);
@@ -427,7 +434,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
             output << **dval;
           else
             output << (*boost::get<DNSDistStats::statfunction_t>(&std::get<1>(e)))(std::get<0>(e));
-          
+
           output << "\n";
         }
 
@@ -448,10 +455,10 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
         output << "# TYPE " << statesbase << "order "       << "gauge"                                                             << "\n";
         output << "# HELP " << statesbase << "weight "      << "The weight within the order in which this server is picked"        << "\n";
         output << "# TYPE " << statesbase << "weight "      << "gauge"                                                             << "\n";
-        
+
         for (const auto& state : *states) {
           string serverName;
-           
+
           if (state->name.empty())
               serverName = state->remote.toStringWithPort();
           else
@@ -484,10 +491,10 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
 
         auto localPools = g_pools.getLocal();
         const string cachebase = "dnsdist_pool_";
-        
+
         for (const auto& entry : *localPools) {
           string poolName = entry.first;
-          
+
           if (poolName.empty()) {
             poolName = "_default_";
           }
@@ -524,18 +531,18 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
       int num=0;
       for(const auto& a : *localServers) {
        string status;
-       if(a->availability == DownstreamState::Availability::Up) 
+       if(a->availability == DownstreamState::Availability::Up)
          status = "UP";
-       else if(a->availability == DownstreamState::Availability::Down) 
+       else if(a->availability == DownstreamState::Availability::Down)
          status = "DOWN";
-       else 
+       else
          status = (a->upStatus ? "up" : "down");
 
        Json::array pools;
        for(const auto& p: a->pools)
          pools.push_back(p);
 
-       Json::object server{ 
+       Json::object server{
          {"id", num++},
          {"name", a->name},
           {"address", a->remote.toStringWithPort()},
@@ -612,7 +619,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
         };
        rules.push_back(rule);
       }
-      
+
       auto responseRules = someResponseRulesToJson(&g_resprulactions);
       auto cacheHitResponseRules = someResponseRulesToJson(&g_cachehitresprulactions);
       auto selfAnsweredResponseRules = someResponseRulesToJson(&g_selfansweredresprulactions);
@@ -631,7 +638,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
         if(!localaddresses.empty()) localaddresses += ", ";
         localaddresses += std::get<0>(loc).toStringWithPort();
       }
+
       Json my_json = Json::object {
         { "daemon_type", "dnsdist" },
         { "version", VERSION},
@@ -831,7 +838,20 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
     close(sock);
   }
 }
-void dnsdistWebserverThread(int sock, const ComboAddress& local, const std::string& password, const std::string& apiKey, const boost::optional<std::map<std::string, std::string> >& customHeaders)
+void setWebserverConfig(const std::string& password, const boost::optional<std::string> apiKey, const boost::optional<std::map<std::string, std::string> > customHeaders)
+{
+  std::lock_guard<std::mutex> lock(g_webserverConfig.lock);
+
+  g_webserverConfig.password = password;
+  if (apiKey) {
+    g_webserverConfig.apiKey = *apiKey;
+  } else {
+    g_webserverConfig.apiKey.clear();
+  }
+  g_webserverConfig.customHeaders = customHeaders;
+}
+
+void dnsdistWebserverThread(int sock, const ComboAddress& local)
 {
   setThreadName("dnsdist/webserv");
   warnlog("Webserver launched on %s", local.toStringWithPort());
@@ -840,7 +860,7 @@ void dnsdistWebserverThread(int sock, const ComboAddress& local, const std::stri
       ComboAddress remote(local);
       int fd = SAccept(sock, remote);
       vinfolog("Got connection from %s", remote.toStringWithPort());
-      std::thread t(connectionThread, fd, remote, password, apiKey, customHeaders);
+      std::thread t(connectionThread, fd, remote);
       t.detach();
     }
     catch(std::exception& e) {
index 08393675e88db78ae9be24bae695479f673d2c7b..8127eac2b87adacfdd29005dee3138463090d37e 100644 (file)
@@ -985,7 +985,16 @@ std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, co
 std::shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq);
 std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq);
 
-void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password, const string& apiKey, const boost::optional<std::map<std::string, std::string> >&);
+struct WebserverConfig
+{
+  std::string password;
+  std::string apiKey;
+  boost::optional<std::map<std::string, std::string> > customHeaders;
+  std::mutex lock;
+};
+
+void setWebserverConfig(const std::string& password, const boost::optional<std::string> apiKey, const boost::optional<std::map<std::string, std::string> > customHeaders);
+void dnsdistWebserverThread(int sock, const ComboAddress& local);
 bool getMsgLen32(int fd, uint32_t* len);
 bool putMsgLen32(int fd, uint32_t len);
 void* tcpAcceptorThread(void* p);
index 0c60c8d94dead396d3fe841ce0976174785fbabf..bb8d8e3dd3956ede0a3c07ea79adcc20a021658f 100644 (file)
@@ -29,6 +29,8 @@ For example, to remove the X-Frame-Options header and add a X-Custom one:
 
   webserver("127.0.0.1:8080", "supersecret", "apikey", {["X-Frame-Options"]= "", ["X-Custom"]="custom"}
 
+Credentials can be changed over time using the :func:`setWebserverConfig` function.
+
 dnsdist API
 -----------
 
index 345a255410b22998045a5af7255627e7c8ee3cdf..86bd479f0bb5a4e40eea257bc62ba9990dfd86bc 100644 (file)
@@ -211,10 +211,10 @@ Control Socket, Console and Webserver
 
   :param int size: The new maximum size.
 
-Webserver
-~~~~~~~~~
+Webserver configuration
+~~~~~~~~~~~~~~~~~~~~~~~
 
-.. function:: webServer(listen_address, password[, apikey[, custom_headers]])
+.. function:: webserver(listen_address, password[, apikey[, custom_headers]])
 
   Launch the :doc:`../guides/webserver` with statistics and the API.
 
@@ -232,6 +232,16 @@ Webserver
   :param bool allow: Set to true to allow modification through the API
   :param str dir: A valid directory where the configuration files will be written by the API.
 
+.. function:: setWebserverConfig(password[, apikey[, custom_headers]])
+
+  .. versionadded:: 1.3.3
+
+  Setup webserver configuration. See :func:`webserver`.
+
+  :param str password: The password required to access the webserver
+  :param str apikey: The key required to access the API
+  :param {[str]=str,...} custom_headers: Allows setting custom headers and removing the defaults
+                 
 Access Control Lists
 ~~~~~~~~~~~~~~~~~~~~
 
index f3047581c4df95b812b93ae44fb58c83becdee75..f3d3ab8d1fc19ed8624e36bc0a9588e52750039f 100644 (file)
@@ -1,10 +1,12 @@
 #!/usr/bin/env python
 import os.path
 
+import base64
 import json
 import requests
 from dnsdisttests import DNSDistTest
 
+
 class TestAPIBasics(DNSDistTest):
 
     _webTimeout = 2.0
@@ -30,6 +32,8 @@ class TestAPIBasics(DNSDistTest):
         """
         for path in self._basicOnlyPaths + self._statsPaths:
             url = 'http://127.0.0.1:' + str(self._webServerPort) + path
+            r = requests.get(url, auth=('whatever', "evilsecret"), timeout=self._webTimeout)
+            self.assertEquals(r.status_code, 401)
             r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
             self.assertTrue(r)
             self.assertEquals(r.status_code, 200)
@@ -45,6 +49,15 @@ class TestAPIBasics(DNSDistTest):
             self.assertTrue(r)
             self.assertEquals(r.status_code, 200)
 
+    def testWrongXAPIKey(self):
+        """
+        API: Wrong X-Api-Key
+        """
+        headers = {'x-api-key': "evilapikey"}
+        for path in self._apiOnlyPaths + self._statsPaths:
+            url = 'http://127.0.0.1:' + str(self._webServerPort) + path
+            r = requests.get(url, headers=headers, timeout=self._webTimeout)
+            self.assertEquals(r.status_code, 401)
     def testBasicAuthOnly(self):
         """
         API: Basic Authentication Only
@@ -370,3 +383,83 @@ class TestAPIWritable(DNSDistTest):
         self.assertEquals(fileContent, """-- Generated by the REST API, DO NOT EDIT
 setACL({"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"})
 """)
+
+class TestAPIAuth(DNSDistTest):
+
+    _webTimeout = 2.0
+    _webServerPort = 8083
+    _webServerBasicAuthPassword = 'secret'
+    _webServerBasicAuthPasswordNew = 'password'
+    _webServerAPIKey = 'apisecret'
+    _webServerAPIKeyNew = 'apipassword'
+    # paths accessible using the API key only
+    _apiOnlyPath = '/api/v1/servers/localhost/config'
+    # paths accessible using basic auth only (list not exhaustive)
+    _basicOnlyPath = '/'
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+    setACL({"127.0.0.1/32", "::1/128"})
+    newServer{address="127.0.0.1:%s"}
+    webserver("127.0.0.1:%s", "%s", "%s")
+    """
+
+    def testBasicAuthChange(self):
+        """
+        API: Basic Authentication updating credentials
+        """
+
+        self.sendConsoleCommand('setWebserverConfig("{}")'.format(self._webServerBasicAuthPasswordNew))
+
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._basicOnlyPath
+        r = requests.get(url, auth=('whatever', self._webServerBasicAuthPasswordNew), timeout=self._webTimeout)
+        self.assertTrue(r)
+        self.assertEquals(r.status_code, 200)
+
+        # Make sure the old password is not usable any more
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._basicOnlyPath
+        r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
+        self.assertEquals(r.status_code, 401)
+
+    def testXAPIKeyChange(self):
+        """
+        API: X-Api-Key updating credentials
+        """
+
+        self.sendConsoleCommand('setWebserverConfig("{}", "{}")'.format(self._webServerBasicAuthPasswordNew, self._webServerAPIKeyNew))
+
+        headers = {'x-api-key': self._webServerAPIKeyNew}
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath
+        r = requests.get(url, headers=headers, timeout=self._webTimeout)
+        self.assertTrue(r)
+        self.assertEquals(r.status_code, 200)
+
+        # Make sure the old password is not usable any more
+        headers = {'x-api-key': self._webServerAPIKey}
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath
+        r = requests.get(url, headers=headers, timeout=self._webTimeout)
+        self.assertEquals(r.status_code, 401)
+
+    def testBasicAuthOnlyChange(self):
+        """
+        API: X-Api-Key updated to none (disabled)
+        """
+
+        self.sendConsoleCommand('setWebserverConfig("{}", "{}")'.format(self._webServerBasicAuthPasswordNew, self._webServerAPIKeyNew))
+
+        headers = {'x-api-key': self._webServerAPIKeyNew}
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath
+        r = requests.get(url, headers=headers, timeout=self._webTimeout)
+        self.assertTrue(r)
+        self.assertEquals(r.status_code, 200)
+
+        # now disable apiKey
+        self.sendConsoleCommand('setWebserverConfig("{}")'.format(self._webServerBasicAuthPasswordNew))
+
+        headers = {'x-api-key': self._webServerAPIKeyNew}
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath
+        r = requests.get(url, headers=headers, timeout=self._webTimeout)
+        self.assertEquals(r.status_code, 401)