]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add support for custom HTTP headers in early responses
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 28 Aug 2019 13:56:43 +0000 (15:56 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 29 Aug 2019 13:44:46 +0000 (15:44 +0200)
pdns/dnsdist-lua-bindings.cc
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/doh.cc
pdns/doh.hh
regression-tests.dnsdist/test_DOH.py

index a8c1f931d3f307cec0b5cc5015265835c66b8475..efa8d94b7bf8c95c53ff924ce9e6be3c14579207 100644 (file)
@@ -369,7 +369,14 @@ void setupLuaBindings(bool client)
     return values;
   });
 
-  g_lua.writeFunction("newDOHResponseMapEntry", [](const std::string& regex, uint16_t status, const std::string& content) {
-    return std::make_shared<DOHResponseMapEntry>(regex, status, content);
+  g_lua.writeFunction("newDOHResponseMapEntry", [](const std::string& regex, uint16_t status, const std::string& content, boost::optional<std::map<std::string, std::string>> customHeaders) {
+    boost::optional<std::vector<std::pair<std::string, std::string>>> headers{boost::none};
+    if (customHeaders) {
+      headers = std::vector<std::pair<std::string, std::string>>();
+      for (const auto& header : *customHeaders) {
+        headers->push_back({ header.first, header.second });
+      }
+    }
+    return std::make_shared<DOHResponseMapEntry>(regex, status, content, headers);
   });
 }
index a202ea01a52730d578cb927139cca6320d2624f4..7c31244a490a0414a9a28221d9732de454311ea9 100644 (file)
@@ -1152,7 +1152,7 @@ DOHFrontend
      :param list of DOHResponseMapEntry objects rules: A list of DOHResponseMapEntry objects, obtained with :func:`newDOHResponseMapEntry`.
 
 
-.. function:: newDOHResponseMapEntry(regex, status, content) -> DOHResponseMapEntry
+.. function:: newDOHResponseMapEntry(regex, status, content [, headers]) -> DOHResponseMapEntry
 
   .. versionadded:: 1.4.0
 
@@ -1162,6 +1162,7 @@ DOHFrontend
   :param str regex: A regular expression to match the path against.
   :param int status: The HTTP code to answer with.
   :param str content: The content of the HTTP response, or a URL if the status is a redirection (3xx).
+  :param table of headers: The custom headers to set for the HTTP response, if any. The default is to use the value of the ``customResponseHeaders`` parameter passed to :func:`addDOHLocal`.
 
 TLSContext
 ~~~~~~~~~~
index c4efb7999e5db89fd1fb65587ca34f6c96ff1ace..24ded5167ae4a061b0de358b89fe98cc060584e1 100644 (file)
@@ -201,8 +201,14 @@ static const std::string& getReasonFromStatusCode(uint16_t statusCode)
   }
 }
 
-static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCode, const std::string& response, const std::string& contentType, bool addContentType)
+static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCode, const std::string& response, const std::vector<std::pair<std::string, std::string>>& customResponseHeaders, const std::string& contentType, bool addContentType)
 {
+  constexpr int overwrite_if_exists = 1;
+  constexpr int maybe_token = 1;
+  for (auto const& headerPair : customResponseHeaders) {
+    h2o_set_header_by_str(&req->pool, &req->res.headers, headerPair.first.c_str(), headerPair.first.size(), maybe_token, headerPair.second.c_str(), headerPair.second.size(), overwrite_if_exists);
+  }
+
   if (statusCode == 200) {
     ++df.d_validresponses;
     req->res.status = 200;
@@ -230,7 +236,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo
   }
   else {
     if (!response.empty()) {
-      h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), response.c_str(), 0);
+      h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), response.c_str(), H2O_SEND_ERROR_KEEP_HEADERS);
     }
     else {
       switch(statusCode) {
@@ -555,12 +561,6 @@ try
     return 0;
   }
 
-  constexpr int overwrite_if_exists = 1;
-  constexpr int maybe_token = 1;
-  for (auto const& headerPair : dsc->df->d_customResponseHeaders) {
-    h2o_set_header_by_str(&req->pool, &req->res.headers, headerPair.first.c_str(), headerPair.first.size(), maybe_token, headerPair.second.c_str(), headerPair.second.size(), overwrite_if_exists);
-  }
-
   if(auto tlsversion = h2o_socket_get_ssl_protocol_version(sock)) {
     if(!strcmp(tlsversion, "TLSv1.0"))
       ++dsc->df->d_tls10queries;
@@ -578,7 +578,8 @@ try
 
   for (const auto& entry : dsc->df->d_responsesMap) {
     if (entry->matches(path)) {
-      handleResponse(*dsc->df, req, entry->getStatusCode(), entry->getContent(), std::string(), false);
+      const auto& customHeaders = entry->getHeaders();
+      handleResponse(*dsc->df, req, entry->getStatusCode(), entry->getContent(), customHeaders ? *customHeaders : dsc->df->d_customResponseHeaders, std::string(), false);
       return 0;
     }
   }
@@ -848,7 +849,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
 
   *du->self = nullptr; // so we don't clean up again in on_generator_dispose
 
-  handleResponse(*dsc->df, du->req, du->status_code, du->response, du->contentType, true);
+  handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true);
 
   delete du;
 }
index a058404be2f25d3ec66f2b17386c4fda8c5ae638..0b90c02c4d18236f597becd1e232f4548c2649d0 100644 (file)
@@ -7,7 +7,7 @@ struct DOHServerConfig;
 class DOHResponseMapEntry
 {
 public:
-  DOHResponseMapEntry(const std::string& regex, uint16_t status, const std::string& content): d_regex(regex), d_content(content), d_status(status)
+  DOHResponseMapEntry(const std::string& regex, uint16_t status, const std::string& content, const boost::optional<std::vector<std::pair<std::string, std::string>>>& headers): d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status)
   {
   }
 
@@ -26,8 +26,14 @@ public:
     return d_content;
   }
 
+  const boost::optional<std::vector<std::pair<std::string, std::string>>>& getHeaders() const
+  {
+    return d_customHeaders;
+  }
+
 private:
   Regex d_regex;
+  boost::optional<std::vector<std::pair<std::string, std::string>>> d_customHeaders;
   std::string d_content;
   uint16_t d_status;
 };
index 7b7e652cead5b721b287576f9c81eac6fc373124..189dc7a4d241341e0382c7553375d20e22631957 100644 (file)
@@ -149,7 +149,7 @@ class TestDOH(DNSDistDOHTest):
 
     addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {customResponseHeaders={["access-control-allow-origin"]="*",["user-agent"]="derp"}})
     dohFE = getDOHFrontend(0)
-    dohFE:setResponsesMap({newDOHResponseMapEntry('^/coffee$', 418, 'C0FFEE')})
+    dohFE:setResponsesMap({newDOHResponseMapEntry('^/coffee$', 418, 'C0FFEE', {['foo']='bar'})})
 
     addAction("drop.doh.tests.powerdns.com.", DropAction())
     addAction("refused.doh.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED))
@@ -568,6 +568,7 @@ class TestDOH(DNSDistDOHTest):
         """
         DOH: HTTP Early Response
         """
+        response_headers = BytesIO()
         url = self._dohBaseURL + 'coffee'
         conn = self.openDOHConnection(self._dohServerPort, caFile=self._caCert, timeout=2.0)
         conn.setopt(pycurl.URL, url)
@@ -575,26 +576,35 @@ class TestDOH(DNSDistDOHTest):
         conn.setopt(pycurl.SSL_VERIFYPEER, 1)
         conn.setopt(pycurl.SSL_VERIFYHOST, 2)
         conn.setopt(pycurl.CAINFO, self._caCert)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
         data = conn.perform_rb()
         rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        headers = response_headers.getvalue()
 
         self.assertEquals(rcode, 418)
         self.assertEquals(data, b'C0FFEE')
+        self.assertIn('foo: bar', headers)
+        self.assertNotIn(self._customResponseHeader2, headers)
 
+        response_headers = BytesIO()
         conn = self.openDOHConnection(self._dohServerPort, caFile=self._caCert, timeout=2.0)
         conn.setopt(pycurl.URL, url)
         conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
         conn.setopt(pycurl.SSL_VERIFYPEER, 1)
         conn.setopt(pycurl.SSL_VERIFYHOST, 2)
         conn.setopt(pycurl.CAINFO, self._caCert)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
         conn.setopt(pycurl.POST, True)
         data = ''
         conn.setopt(pycurl.POSTFIELDS, data)
 
         data = conn.perform_rb()
         rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        headers = response_headers.getvalue()
         self.assertEquals(rcode, 418)
         self.assertEquals(data, b'C0FFEE')
+        self.assertIn('foo: bar', headers)
+        self.assertNotIn(self._customResponseHeader2, headers)
 
 class TestDOHAddingECS(DNSDistDOHTest):