From 80dbd7d275f17b5ad98c1df9c638a968a1e1ed67 Mon Sep 17 00:00:00 2001 From: Charles-Henri Bruyand Date: Tue, 30 Oct 2018 14:29:51 +0100 Subject: [PATCH] dnsdist: add ability to update webserver credentials --- pdns/dnsdist-console.cc | 1 + pdns/dnsdist-lua.cc | 8 +- pdns/dnsdist-web.cc | 66 +++++++++------ pdns/dnsdist.hh | 11 ++- pdns/dnsdistdist/docs/guides/webserver.rst | 2 + pdns/dnsdistdist/docs/reference/config.rst | 16 +++- regression-tests.dnsdist/test_API.py | 93 ++++++++++++++++++++++ 7 files changed, 169 insertions(+), 28 deletions(-) diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index d0195c9ec..874ea4aac 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -446,6 +446,7 @@ const std::vector g_consoleKeywords{ { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. Default to 1 which means that the feature is disabled and recvmsg() is used instead" }, { "setUDPTimeout", true, "n", "set the maximum time dnsdist will wait for a response from a backend over UDP, in seconds" }, { "setVerboseHealthChecks", true, "bool", "set whether health check errors will be logged" }, + { "setWebserverConfig", true, "password [, apiKey [, customHeaders ]]", "Updates webserver configuration" }, { "show", true, "string", "outputs `string`" }, { "showACL", true, "", "show our ACL set" }, { "showBinds", true, "", "show listening addresses (frontends)" }, diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 40abb4bce..cae1df018 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -631,7 +631,8 @@ void setupLuaConfig(bool client) SBind(sock, local); SListen(sock, 5); auto launch=[sock, local, password, apiKey, customHeaders]() { - thread t(dnsdistWebserverThread, sock, local, password, apiKey ? *apiKey : "", customHeaders); + setWebserverConfig(password, apiKey, customHeaders); + thread t(dnsdistWebserverThread, sock, local); t.detach(); }; if(g_launchWork) @@ -646,6 +647,11 @@ void setupLuaConfig(bool client) }); + g_lua.writeFunction("setWebserverConfig", [](const std::string& password, const boost::optional apiKey, const boost::optional > customHeaders) { + setLuaSideEffect(); + setWebserverConfig(password, apiKey, customHeaders); + }); + g_lua.writeFunction("controlSocket", [client](const std::string& str) { setLuaSideEffect(); ComboAddress local(str, 5199); diff --git a/pdns/dnsdist-web.cc b/pdns/dnsdist-web.cc index a50d68f55..f038b2cf9 100644 --- a/pdns/dnsdist-web.cc +++ b/pdns/dnsdist-web.cc @@ -38,6 +38,7 @@ #include bool g_apiReadWrite{false}; +WebserverConfig g_webserverConfig; std::string g_apiConfigDirectory; static bool apiWriteConfigFile(const string& filebasename, const string& content) @@ -133,23 +134,25 @@ static bool isAStatsRequest(const YaHTTP::Request& req) return req.url.path == "/jsonstat" || req.url.path == "/metrics"; } -static bool compareAuthorization(const YaHTTP::Request& req, const string &expected_password, const string& expectedApiKey) +static bool compareAuthorization(const YaHTTP::Request& req) { + std::lock_guard lock(g_webserverConfig.lock); + if (isAnAPIRequest(req)) { /* Access to the API requires a valid API key */ - if (checkAPIKey(req, expectedApiKey)) { + if (checkAPIKey(req, g_webserverConfig.apiKey)) { return true; } - return isAnAPIRequestAllowedWithWebAuth(req) && checkWebPassword(req, expected_password); + return isAnAPIRequestAllowedWithWebAuth(req) && checkWebPassword(req, g_webserverConfig.password); } if (isAStatsRequest(req)) { /* Access to the stats is allowed for both API and Web users */ - return checkAPIKey(req, expectedApiKey) || checkWebPassword(req, expected_password); + return checkAPIKey(req, g_webserverConfig.apiKey) || checkWebPassword(req, g_webserverConfig.password); } - return checkWebPassword(req, expected_password); + return checkWebPassword(req, g_webserverConfig.password); } static bool isMethodAllowed(const YaHTTP::Request& req) @@ -241,7 +244,7 @@ static json11::Json::array someResponseRulesToJson(GlobalStateHolder>* return responseRules; } -static void connectionThread(int sock, ComboAddress remote, string password, string apiKey, const boost::optional >& customHeaders) +static void connectionThread(int sock, ComboAddress remote) { setThreadName("dnsdist/webConn"); @@ -276,8 +279,12 @@ static void connectionThread(int sock, ComboAddress remote, string password, str resp.version = req.version; const string charset = "; charset=utf-8"; - addCustomHeaders(resp, customHeaders); - addSecurityHeaders(resp, customHeaders); + { + std::lock_guard lock(g_webserverConfig.lock); + + addCustomHeaders(resp, g_webserverConfig.customHeaders); + addSecurityHeaders(resp, g_webserverConfig.customHeaders); + } /* indicate that the connection will be closed after completion of the response */ resp.headers["Connection"] = "close"; @@ -289,7 +296,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str handleCORS(req, resp); resp.status=200; } - else if (!compareAuthorization(req, password, apiKey)) { + else if (!compareAuthorization(req)) { YaHTTP::strstr_map_t::iterator header = req.headers.find("authorization"); if (header != req.headers.end()) errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, remote.toStringWithPort()); @@ -402,7 +409,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str // Prometheus suggest using '_' instead of '-' std::string prometheusMetricName = "dnsdist_" + boost::replace_all_copy(metricName, "-", "_"); - MetricDefinition metricDetails; + MetricDefinition metricDetails; if (!g_metricDefinitions.getMetricDetails(metricName, metricDetails)) { vinfolog("Do not have metric details for %s", metricName); @@ -427,7 +434,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str output << **dval; else output << (*boost::get(&std::get<1>(e)))(std::get<0>(e)); - + output << "\n"; } @@ -448,10 +455,10 @@ static void connectionThread(int sock, ComboAddress remote, string password, str output << "# TYPE " << statesbase << "order " << "gauge" << "\n"; output << "# HELP " << statesbase << "weight " << "The weight within the order in which this server is picked" << "\n"; output << "# TYPE " << statesbase << "weight " << "gauge" << "\n"; - + for (const auto& state : *states) { string serverName; - + if (state->name.empty()) serverName = state->remote.toStringWithPort(); else @@ -484,10 +491,10 @@ static void connectionThread(int sock, ComboAddress remote, string password, str auto localPools = g_pools.getLocal(); const string cachebase = "dnsdist_pool_"; - + for (const auto& entry : *localPools) { string poolName = entry.first; - + if (poolName.empty()) { poolName = "_default_"; } @@ -524,18 +531,18 @@ static void connectionThread(int sock, ComboAddress remote, string password, str int num=0; for(const auto& a : *localServers) { string status; - if(a->availability == DownstreamState::Availability::Up) + if(a->availability == DownstreamState::Availability::Up) status = "UP"; - else if(a->availability == DownstreamState::Availability::Down) + else if(a->availability == DownstreamState::Availability::Down) status = "DOWN"; - else + else status = (a->upStatus ? "up" : "down"); Json::array pools; for(const auto& p: a->pools) pools.push_back(p); - Json::object server{ + Json::object server{ {"id", num++}, {"name", a->name}, {"address", a->remote.toStringWithPort()}, @@ -612,7 +619,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str }; rules.push_back(rule); } - + auto responseRules = someResponseRulesToJson(&g_resprulactions); auto cacheHitResponseRules = someResponseRulesToJson(&g_cachehitresprulactions); auto selfAnsweredResponseRules = someResponseRulesToJson(&g_selfansweredresprulactions); @@ -631,7 +638,7 @@ static void connectionThread(int sock, ComboAddress remote, string password, str if(!localaddresses.empty()) localaddresses += ", "; localaddresses += std::get<0>(loc).toStringWithPort(); } - + Json my_json = Json::object { { "daemon_type", "dnsdist" }, { "version", VERSION}, @@ -831,7 +838,20 @@ static void connectionThread(int sock, ComboAddress remote, string password, str close(sock); } } -void dnsdistWebserverThread(int sock, const ComboAddress& local, const std::string& password, const std::string& apiKey, const boost::optional >& customHeaders) +void setWebserverConfig(const std::string& password, const boost::optional apiKey, const boost::optional > customHeaders) +{ + std::lock_guard lock(g_webserverConfig.lock); + + g_webserverConfig.password = password; + if (apiKey) { + g_webserverConfig.apiKey = *apiKey; + } else { + g_webserverConfig.apiKey.clear(); + } + g_webserverConfig.customHeaders = customHeaders; +} + +void dnsdistWebserverThread(int sock, const ComboAddress& local) { setThreadName("dnsdist/webserv"); warnlog("Webserver launched on %s", local.toStringWithPort()); @@ -840,7 +860,7 @@ void dnsdistWebserverThread(int sock, const ComboAddress& local, const std::stri ComboAddress remote(local); int fd = SAccept(sock, remote); vinfolog("Got connection from %s", remote.toStringWithPort()); - std::thread t(connectionThread, fd, remote, password, apiKey, customHeaders); + std::thread t(connectionThread, fd, remote); t.detach(); } catch(std::exception& e) { diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 08393675e..8127eac2b 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -985,7 +985,16 @@ std::shared_ptr whashed(const NumberedServerVector& servers, co std::shared_ptr chashed(const NumberedServerVector& servers, const DNSQuestion* dq); std::shared_ptr roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq); -void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password, const string& apiKey, const boost::optional >&); +struct WebserverConfig +{ + std::string password; + std::string apiKey; + boost::optional > customHeaders; + std::mutex lock; +}; + +void setWebserverConfig(const std::string& password, const boost::optional apiKey, const boost::optional > customHeaders); +void dnsdistWebserverThread(int sock, const ComboAddress& local); bool getMsgLen32(int fd, uint32_t* len); bool putMsgLen32(int fd, uint32_t len); void* tcpAcceptorThread(void* p); diff --git a/pdns/dnsdistdist/docs/guides/webserver.rst b/pdns/dnsdistdist/docs/guides/webserver.rst index 0c60c8d94..bb8d8e3dd 100644 --- a/pdns/dnsdistdist/docs/guides/webserver.rst +++ b/pdns/dnsdistdist/docs/guides/webserver.rst @@ -29,6 +29,8 @@ For example, to remove the X-Frame-Options header and add a X-Custom one: webserver("127.0.0.1:8080", "supersecret", "apikey", {["X-Frame-Options"]= "", ["X-Custom"]="custom"} +Credentials can be changed over time using the :func:`setWebserverConfig` function. + dnsdist API ----------- diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index 345a25541..86bd479f0 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -211,10 +211,10 @@ Control Socket, Console and Webserver :param int size: The new maximum size. -Webserver -~~~~~~~~~ +Webserver configuration +~~~~~~~~~~~~~~~~~~~~~~~ -.. function:: webServer(listen_address, password[, apikey[, custom_headers]]) +.. function:: webserver(listen_address, password[, apikey[, custom_headers]]) Launch the :doc:`../guides/webserver` with statistics and the API. @@ -232,6 +232,16 @@ Webserver :param bool allow: Set to true to allow modification through the API :param str dir: A valid directory where the configuration files will be written by the API. +.. function:: setWebserverConfig(password[, apikey[, custom_headers]]) + + .. versionadded:: 1.3.3 + + Setup webserver configuration. See :func:`webserver`. + + :param str password: The password required to access the webserver + :param str apikey: The key required to access the API + :param {[str]=str,...} custom_headers: Allows setting custom headers and removing the defaults + Access Control Lists ~~~~~~~~~~~~~~~~~~~~ diff --git a/regression-tests.dnsdist/test_API.py b/regression-tests.dnsdist/test_API.py index f3047581c..f3d3ab8d1 100644 --- a/regression-tests.dnsdist/test_API.py +++ b/regression-tests.dnsdist/test_API.py @@ -1,10 +1,12 @@ #!/usr/bin/env python import os.path +import base64 import json import requests from dnsdisttests import DNSDistTest + class TestAPIBasics(DNSDistTest): _webTimeout = 2.0 @@ -30,6 +32,8 @@ class TestAPIBasics(DNSDistTest): """ for path in self._basicOnlyPaths + self._statsPaths: url = 'http://127.0.0.1:' + str(self._webServerPort) + path + r = requests.get(url, auth=('whatever', "evilsecret"), timeout=self._webTimeout) + self.assertEquals(r.status_code, 401) r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout) self.assertTrue(r) self.assertEquals(r.status_code, 200) @@ -45,6 +49,15 @@ class TestAPIBasics(DNSDistTest): self.assertTrue(r) self.assertEquals(r.status_code, 200) + def testWrongXAPIKey(self): + """ + API: Wrong X-Api-Key + """ + headers = {'x-api-key': "evilapikey"} + for path in self._apiOnlyPaths + self._statsPaths: + url = 'http://127.0.0.1:' + str(self._webServerPort) + path + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertEquals(r.status_code, 401) def testBasicAuthOnly(self): """ API: Basic Authentication Only @@ -370,3 +383,83 @@ class TestAPIWritable(DNSDistTest): self.assertEquals(fileContent, """-- Generated by the REST API, DO NOT EDIT setACL({"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"}) """) + +class TestAPIAuth(DNSDistTest): + + _webTimeout = 2.0 + _webServerPort = 8083 + _webServerBasicAuthPassword = 'secret' + _webServerBasicAuthPasswordNew = 'password' + _webServerAPIKey = 'apisecret' + _webServerAPIKeyNew = 'apipassword' + # paths accessible using the API key only + _apiOnlyPath = '/api/v1/servers/localhost/config' + # paths accessible using basic auth only (list not exhaustive) + _basicOnlyPath = '/' + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%s") + setACL({"127.0.0.1/32", "::1/128"}) + newServer{address="127.0.0.1:%s"} + webserver("127.0.0.1:%s", "%s", "%s") + """ + + def testBasicAuthChange(self): + """ + API: Basic Authentication updating credentials + """ + + self.sendConsoleCommand('setWebserverConfig("{}")'.format(self._webServerBasicAuthPasswordNew)) + + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._basicOnlyPath + r = requests.get(url, auth=('whatever', self._webServerBasicAuthPasswordNew), timeout=self._webTimeout) + self.assertTrue(r) + self.assertEquals(r.status_code, 200) + + # Make sure the old password is not usable any more + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._basicOnlyPath + r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout) + self.assertEquals(r.status_code, 401) + + def testXAPIKeyChange(self): + """ + API: X-Api-Key updating credentials + """ + + self.sendConsoleCommand('setWebserverConfig("{}", "{}")'.format(self._webServerBasicAuthPasswordNew, self._webServerAPIKeyNew)) + + headers = {'x-api-key': self._webServerAPIKeyNew} + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertTrue(r) + self.assertEquals(r.status_code, 200) + + # Make sure the old password is not usable any more + headers = {'x-api-key': self._webServerAPIKey} + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertEquals(r.status_code, 401) + + def testBasicAuthOnlyChange(self): + """ + API: X-Api-Key updated to none (disabled) + """ + + self.sendConsoleCommand('setWebserverConfig("{}", "{}")'.format(self._webServerBasicAuthPasswordNew, self._webServerAPIKeyNew)) + + headers = {'x-api-key': self._webServerAPIKeyNew} + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertTrue(r) + self.assertEquals(r.status_code, 200) + + # now disable apiKey + self.sendConsoleCommand('setWebserverConfig("{}")'.format(self._webServerBasicAuthPasswordNew)) + + headers = {'x-api-key': self._webServerAPIKeyNew} + url = 'http://127.0.0.1:' + str(self._webServerPort) + self._apiOnlyPath + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertEquals(r.status_code, 401) -- 2.40.0