--enable-unit-tests \
--enable-libsodium \
--enable-dnscrypt \
+ --enable-dns-over-tls \
--prefix=$HOME/dnsdist \
--disable-silent-rules"
run "make -k -j3"
{ "addLuaResponseAction", true, "x, func", "where 'x' is all the combinations from `addAction`, and func is a function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing" },
{ "addCacheHitResponseAction", true, "DNS rule, DNS response action", "add a cache hit response rule" },
{ "addResponseAction", true, "DNS rule, DNS response action", "add a response rule" },
+ { "addTLSLocal", true, "addr, certFile, keyFile[,params]", "listen to incoming DNS over TLS queries on the specified address using the specified certificate and key. The last parameter is a table" },
{ "AllowAction", true, "", "let these packets go through" },
{ "AllowResponseAction", true, "", "let these packets go through" },
{ "AllRule", true, "", "matches all traffic" },
{ "getResponseRing", true, "", "return the current content of the response ring" },
{ "getServer", true, "n", "returns server with index n" },
{ "getServers", true, "", "returns a table with all defined servers" },
+ { "getTLSContext", true, "n", "returns the TLS context with index n" },
{ "inClientStartup", true, "", "returns true during console client parsing of configuration" },
{ "grepq", true, "Netmask|DNS Name|100ms|{\"::1\", \"powerdns.com\", \"100ms\"} [, n]", "shows the last n queries and responses matching the specified client address or range (Netmask), or the specified DNS Name, or slower than 100ms" },
{ "leastOutstanding", false, "", "Send traffic to downstream server with least outstanding queries, with the lowest 'order', and within that the lowest recent latency"},
{ "showServerPolicy", true, "", "show name of currently operational server selection policy" },
{ "showServers", true, "", "output all servers" },
{ "showTCPStats", true, "", "show some statistics regarding TCP" },
+ { "showTLSContext", true, "", "list all the available TLS contexts" },
{ "showVersion", true, "", "show the current version" },
{ "shutdown", true, "", "shut down `dnsdist`" },
{ "snmpAgent", true, "enableTraps [, masterSocket]", "enable `SNMP` support. `enableTraps` is a boolean indicating whether traps should be sent and `masterSocket` an optional string specifying how to connect to the master agent"},
#include "dnswriter.hh"
#include "dolog.hh"
#include "lock.hh"
+#include "protobuf.hh"
#include "sodcrypto.hh"
#include <boost/logic/tribool.hpp>
g_lua.writeFunction("shutdown", []() {
#ifdef HAVE_SYSTEMD
sd_notify(0, "STOPPING=1");
-#endif
+#endif /* HAVE_SYSTEMD */
+#if 0
+ // Useful for debugging leaks, but might lead to race under load
+ // since other threads are still runing.
+ for(auto& frontend : g_tlslocals) {
+ frontend->cleanup();
+ }
+ g_tlslocals.clear();
+#ifdef HAVE_PROTOBUF
+ google::protobuf::ShutdownProtobufLibrary();
+#endif /* HAVE_PROTOBUF */
+#endif /* 0 */
_exit(0);
} );
g_outputBuffer="recvmmsg support is not available!\n";
#endif
});
+
+ g_lua.writeFunction("addTLSLocal", [client](const std::string& addr, const std::string& certFile, const std::string& keyFile, boost::optional<localbind_t> vars) {
+ if (client)
+ return;
+#ifdef HAVE_DNS_OVER_TLS
+ setLuaSideEffect();
+ if (g_configurationDone) {
+ g_outputBuffer="addTLSLocal cannot be used at runtime!\n";
+ return;
+ }
+ shared_ptr<TLSFrontend> frontend = std::make_shared<TLSFrontend>();
+ frontend->d_certFile = certFile;
+ frontend->d_keyFile = keyFile;
+
+ if (vars) {
+ bool doTCP = true;
+ parseLocalBindVars(vars, doTCP, frontend->d_reusePort, frontend->d_tcpFastOpenQueueSize, frontend->d_interface, frontend->d_cpus);
+
+ if (vars->count("provider")) {
+ frontend->d_provider = boost::get<const string>((*vars)["provider"]);
+ }
+
+ if (vars->count("ciphers")) {
+ frontend->d_ciphers = boost::get<const string>((*vars)["ciphers"]);
+ }
+
+ if (vars->count("ticketKeyFile")) {
+ frontend->d_ticketKeyFile = boost::get<const string>((*vars)["ticketKeyFile"]);
+ }
+
+ if (vars->count("ticketsKeysRotationDelay")) {
+ frontend->d_ticketsKeyRotationDelay = std::stoi(boost::get<const string>((*vars)["ticketsKeysRotationDelay"]));
+ }
+
+ if (vars->count("numberOfTicketsKeys")) {
+ frontend->d_numberOfTicketsKeys = std::stoi(boost::get<const string>((*vars)["numberOfTicketsKeys"]));
+ }
+ }
+
+ try {
+ frontend->d_addr = ComboAddress(addr, 853);
+ vinfolog("Loading TLS provider %s", frontend->d_provider);
+ g_tlslocals.push_back(frontend); /// only works pre-startup, so no sync necessary
+ }
+ catch(const std::exception& e) {
+ g_outputBuffer="Error: "+string(e.what())+"\n";
+ }
+#else
+ g_outputBuffer="DNS over TLS support is not present!\n";
+#endif
+ });
+
+ g_lua.writeFunction("showTLSContexts", [client]() {
+#ifdef HAVE_DNS_OVER_TLS
+ setLuaNoSideEffect();
+ try {
+ ostringstream ret;
+ boost::format fmt("%1$-3d %2$-20.20s %|25t|%3$-14d %|40t|%4$-14d %|54t|%5$-21.21s");
+ // 1 2 3 4 5
+ ret << (fmt % "#" % "Address" % "# ticket keys" % "Rotation delay" % "Next rotation" ) << endl;
+ size_t counter = 0;
+ for (const auto& ctx : g_tlslocals) {
+ ret << (fmt % counter % ctx->d_addr.toStringWithPort() % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl;
+ counter++;
+ }
+ g_outputBuffer = ret.str();
+ }
+ catch(const std::exception& e) {
+ g_outputBuffer = e.what();
+ throw;
+ }
+#else
+ g_outputBuffer="DNS over TLS support is not present!\n";
+#endif
+ });
+
+ g_lua.writeFunction("getTLSContext", [client](size_t index) {
+ std::shared_ptr<TLSCtx> result = nullptr;
+#ifdef HAVE_DNS_OVER_TLS
+ setLuaNoSideEffect();
+ try {
+ if (index < g_tlslocals.size()) {
+ result = g_tlslocals.at(index)->getContext();
+ }
+ else {
+ errlog("Error: trying to get TLS context with index %zu but we only have %zu\n", index, g_tlslocals.size());
+ g_outputBuffer="Error: trying to get TLS context with index " + std::to_string(index) + " but we only have " + std::to_string(g_tlslocals.size()) + "\n";
+ }
+ }
+ catch(const std::exception& e) {
+ g_outputBuffer="Error: "+string(e.what())+"\n";
+ errlog("Error: %s\n", string(e.what()));
+ }
+#else
+ g_outputBuffer="DNS over TLS support is not present!\n";
+#endif
+ return result;
+ });
+
+ g_lua.registerFunction<void(std::shared_ptr<TLSCtx>::*)()>("rotateTicketsKey", [](std::shared_ptr<TLSCtx> ctx) {
+ if (ctx != nullptr) {
+ ctx->rotateTicketsKey(time(nullptr));
+ }
+ });
+
+ g_lua.registerFunction<void(std::shared_ptr<TLSCtx>::*)(const std::string&)>("loadTicketsKeys", [](std::shared_ptr<TLSCtx> ctx, const std::string& file) {
+ if (ctx != nullptr) {
+ ctx->loadTicketsKeys(file);
+ }
+ });
}
vector<std::function<void(void)>> setupLua(bool client, const std::string& config)
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
+#include "tcpiohandler.hh"
#include <thread>
#include <atomic>
using std::thread;
using std::atomic;
-/* TCP: the grand design.
- We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
- An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
+/* TCP: the grand design.
+ We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
+ An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
we will not go there.
In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
return false;
}
-static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
+static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len)
+try
{
- return sendSizeAndMsgWithTimeout(fd, responseLen, response, g_tcpSendTimeout, nullptr, nullptr, 0, 0, 0);
+ uint16_t raw;
+ size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout);
+ if(ret != sizeof raw)
+ return false;
+ *len = ntohs(raw);
+ return true;
+}
+catch(...) {
+ return false;
}
static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
{
/* we get launched with a pipe on which we receive file descriptors from clients that we own
from that point on */
-
+
bool outstanding = false;
time_t lastTCPCleanup = time(nullptr);
g_tcpclientthreads->decrementQueuedCount();
ci=*citmp;
- delete citmp;
+ delete citmp;
uint16_t qlen, rlen;
vector<uint8_t> rewrittenResponse;
}
try {
+ TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime);
+
for(;;) {
unsigned int remainingTime = 0;
ds = nullptr;
outstanding = false;
- if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout)) {
+ if(!getNonBlockingMsgLenFromClient(handler, &qlen)) {
break;
}
queryBuffer.reserve(qlen + 512);
char* query = &queryBuffer[0];
- readn2WithTimeout(ci.fd, query, qlen, g_tcpRecvTimeout, remainingTime);
-
+ handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
#ifdef HAVE_DNSCRYPT
std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;
if (!decrypted) {
if (response.size() > 0) {
- sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
+ handler.writeSizeAndMsg(response.data(), response.size(), g_tcpSendTimeout);
}
break;
}
goto drop;
}
#endif
- sendResponseToClient(ci.fd, query, dq.len);
+ handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
g_stats.selfAnswered++;
continue;
}
goto drop;
}
#endif
- sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
+ handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
g_stats.cacheHits++;
continue;
}
goto drop;
}
#endif
- sendResponseToClient(ci.fd, query, dq.len);
+ handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
continue;
}
goto drop;
}
#endif
- if (!sendResponseToClient(ci.fd, response, responseLen)) {
+ if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
break;
}
rewrittenResponse.clear();
}
}
- catch(...){}
+ catch(...) {}
drop:;
-
+
vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
if (ci.fd >= 0) {
close(ci.fd);
}
ci.fd = -1;
+
if (ds && outstanding) {
outstanding = false;
--ds->outstanding;
return 0;
}
-
bool getMsgLen32(int fd, uint32_t* len)
try
{
GlobalStateHolder<NetmaskGroup> g_ACL;
string g_outputBuffer;
+
vector<std::tuple<ComboAddress, bool, bool, int, string, std::set<int>>> g_locals;
+std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
#ifdef HAVE_DNSCRYPT
std::vector<std::tuple<ComboAddress,DnsCryptContext,bool, int, string, std::set<int>>> g_dnsCryptLocals;
#endif
cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<")"<<endl;
#endif
cout<<"Enabled features: ";
+#ifdef HAVE_DNS_OVER_TLS
+ cout<<"dns-over-tls(";
+#ifdef HAVE_GNUTLS
+ cout<<"gnutls ";
+#endif
+#ifdef HAVE_LIBSSL
+ cout<<"openssl";
+#endif
+ cout<<") ";
+#endif
#ifdef HAVE_DNSCRYPT
cout<<"dnscrypt ";
#endif
}
#endif
+ for(auto& frontend : g_tlslocals) {
+ ClientState* cs = new ClientState;
+ cs->local = frontend->d_addr;
+ cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0);
+ SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1);
+#ifdef TCP_DEFER_ACCEPT
+ SSetsockopt(cs->tcpFD, SOL_TCP,TCP_DEFER_ACCEPT, 1);
+#endif
+ if (frontend->d_tcpFastOpenQueueSize > 0) {
+#ifdef TCP_FASTOPEN
+ SSetsockopt(cs->tcpFD, SOL_TCP, TCP_FASTOPEN, frontend->d_tcpFastOpenQueueSize);
+#else
+ warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
+#endif
+ }
+ if (frontend->d_reusePort) {
+#ifdef SO_REUSEPORT
+ SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEPORT, 1);
+#else
+ warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", cs.local.toStringWithPort());
+#endif
+ }
+ if(cs->local.sin4.sin_family == AF_INET6) {
+ SSetsockopt(cs->tcpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
+ }
+
+ if (!frontend->d_interface.empty()) {
+#ifdef SO_BINDTODEVICE
+ int res = setsockopt(cs->tcpFD, SOL_SOCKET, SO_BINDTODEVICE, frontend->d_interface.c_str(), frontend->d_interface.length());
+ if (res != 0) {
+ warnlog("Error setting up the interface on local address '%s': %s", cs->local.toStringWithPort(), strerror(errno));
+ }
+#else
+ warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", cs->local.toStringWithPort());
+#endif
+ }
+
+ cs->cpus = frontend->d_cpus;
+
+ bindAny(cs->local.sin4.sin_family, cs->tcpFD);
+ if (frontend->setupTLS()) {
+ cs->tlsFrontend = frontend;
+ SBind(cs->tcpFD, cs->local);
+ SListen(cs->tcpFD, 64);
+ warnlog("Listening on %s for TLS", cs->local.toStringWithPort());
+ toLaunch.push_back(cs);
+ g_frontends.push_back(cs);
+ tcpBindsCount++;
+ }
+ else {
+ delete cs;
+ errlog("Error while setting up TLS on local address '%s', exiting", cs->local.toStringWithPort());
+ _exit(EXIT_FAILURE);
+ }
+ }
+
if(g_cmdLine.beDaemon) {
g_console=false;
daemonize();
#include "bpf-filter.hh"
#include <string>
#include <unordered_map>
-
+#include "tcpiohandler.hh"
#ifdef HAVE_PROTOBUF
#include <boost/uuid/uuid.hpp>
#ifdef HAVE_DNSCRYPT
DnsCryptContext* dnscryptCtx{0};
#endif
+ shared_ptr<TLSFrontend> tlsFrontend;
std::atomic<uint64_t> queries{0};
int udpFD{-1};
int tcpFD{-1};
extern ComboAddress g_serverControl; // not changed during runtime
extern std::vector<std::tuple<ComboAddress, bool, bool, int, std::string, std::set<int>>> g_locals; // not changed at runtime (we hope XXX)
+extern std::vector<shared_ptr<TLSFrontend>> g_tlslocals;
extern vector<ClientState*> g_frontends;
extern std::string g_key; // in theory needs locking
extern bool g_truncateTC;
AM_CPPFLAGS += $(RE2_CFLAGS)
endif
+if HAVE_DNS_OVER_TLS
+if HAVE_LIBSSL
+AM_CPPFLAGS += $(LIBSSL_CFLAGS)
+endif
+
+if HAVE_GNUTLS
+AM_CPPFLAGS += $(GNUTLS_CFLAGS)
+endif
+endif
EXTRA_DIST=dnslabeltext.rl \
dnsdistconf.lua \
sodcrypto.cc sodcrypto.hh \
sstuff.hh \
statnode.cc statnode.hh \
+ tcpiohandler.cc tcpiohandler.hh \
ext/luawrapper/include/LuaContext.hpp \
ext/json11/json11.cpp \
ext/json11/json11.hpp \
dnsdist_LDADD += $(RE2_LIBS)
endif
+if HAVE_DNS_OVER_TLS
+if HAVE_GNUTLS
+dnsdist_LDADD += -lgnutls
+endif
+
+if HAVE_LIBSSL
+dnsdist_LDADD += $(LIBSSL_LIBS) $(LIBCRYPTO_LIBS)
+endif
+endif
+
if !HAVE_LUA_HPP
BUILT_SOURCES += lua.hpp
nodist_dnsdist_SOURCES = lua.hpp
])
PDNS_CHECK_LUA_HPP
+DNSDIST_ENABLE_DNS_OVER_TLS
+DNSDIST_CHECK_GNUTLS
+DNSDIST_CHECK_LIBSSL
+AS_IF([test "x$enable_dns_over_tls" != "xno"], [
+ AS_IF([test "$HAVE_LIBSSL" = "1"], [
+ # we need libcrypto if libssl is enabled
+ PDNS_CHECK_LIBCRYPTO
+ ])
+ AS_IF([test "$HAVE_GNUTLS" = "0" -a "$HAVE_LIBSSL" = "0"], [
+ AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
+ ])
+])
+
AX_CXX_COMPILE_STDCXX_11([ext], [mandatory])
AC_MSG_CHECKING([whether we will enable compiler security checks])
[AC_MSG_NOTICE([SNMP: yes])],
[AC_MSG_NOTICE([SNMP: no])]
)
+AS_IF([test "x$enable_dns_over_tls" != "xno"],
+ [AC_MSG_NOTICE([DNS over TLS: yes])],
+ [AC_MSG_NOTICE([DNS over TLS: no])]
+)
+AS_IF([test "x$GNUTLS_LIBS" != "x"],
+ [AC_MSG_NOTICE([GnuTLS: yes])],
+ [AC_MSG_NOTICE([GnuTLS: no])]
+)
+AS_IF([test "x$LIBSSL_LIBS" != "x"],
+ [AC_MSG_NOTICE([OpenSSL: yes])],
+ [AC_MSG_NOTICE([OpenSSL: no])]
+)
AC_MSG_NOTICE([])
higher than 0 to enable TCP Fast Open when available.
Default is 0.
+.. function:: addTLSLocal(address, certFile, keyFile[, options])
+
+ .. versionadded:: 1.3.0
+
+ Listen on the specified address and TCP port for incoming DNS over TLS connections, presenting the specified X.509 certificate.
+
+ :param str address: The IP Address with an optional port to listen on.
+ The default port is 853.
+ :param str certFile: The path to a X.509 certificate file in PEM format.
+ :param str keyFile: The path to the private key file corresponding to the certificate.
+ :param table options: A table with key: value pairs with listen options.
+
+ Options:
+
+ * ``doTCP=true``: bool - Also bind on TCP on ``address``.
+ * ``reusePort=false``: bool - Set the ``SO_REUSEPORT`` socket option.
+ * ``tcpFastOpenSize=0``: int - Set the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0.
+ * ``interface=""``: str - Set the network interface to use.
+ * ``cpus={}``: table - Set the CPU affinity for this listener thread, asking the scheduler to run it on a single CPU id, or a set of CPU ids. This parameter is only available if the OS provides the pthread_setaffinity_np() function.
+ * ``provider``: str - The TLS library to use between GnuTLS and OpenSSL, if they were available and enabled at compilation time.
+ * ``ciphers``: str - The TLS ciphers to use. The exact format depends on the provider used.
+ * ``numberOfTicketsKeys``: int - The maximum number of tickets keys to keep in memory at the same time, if the provider supports it (GnuTLS doesn't, OpenSSL does). Only one key is marked as active and used to encrypt new tickets while the remaining ones can still be used to decrypt existing tickets after a rotation. Default to 5.
+ * ``ticketKeyFile``: str - The path to a file from where TLS tickets keys should be loaded, to support RFC 5077. These keys should be rotated often and never written to persistent storage to preserve forward secrecy. The default is to generate a random key. The OpenSSL provider supports several tickets keys to be able to decrypt existing sessions after the rotation, while the GnuTLS provider only supports one key.
+ * ``ticketsKeysRotationDelay``: int - Set the delay before the TLS tickets key is rotated, in seconds. Default is 43200 (12h).
+
.. function:: setLocal(address[, options])
.. versionadded:: 1.2.0
Print all statistics dnsdist gathers
+.. function:: getTLSContext(idx)
+ .. versionadded:: 1.3.0
+
+ Return the TLSContext object for the context of index ``idx``.
+
.. function:: grepq(selector[, num])
grepq(selectors[, num])
Show some statistics regarding TCP
+.. function:: showTLSContexts()
+ .. versionadded:: 1.3.0
+
+ Print the list of all availables DNS over TLS contexts.
+
.. function:: showVersion()
Print the version of dnsdist
If this function exists, it is called every second to so regular tasks.
This can be used for e.g. :doc:`Dynamic Blocks <../guides/dynblocks>`.
+
+TLSContext
+~~~~~~~~~~
+
+.. class:: TLSContext
+ .. versionadded:: 1.3.0
+
+ This object represents an address and port dnsdist is listening on for DNS over TLS queries.
+
+.. classmethod:: TLSContext:rotateTicketsKey()
+
+ Replace the current TLS tickets key by a new random one.
+
+.. classmethod:: TLSContext:loadTicketsKeys(ticketsKeysFile)
+
+ Load new tickets keys from the selected file, replacing the existing ones. These keys should be rotated often and never written to persistent storage to preserve forward secrecy. The default is to generate a random key. The OpenSSL provider supports several tickets keys to be able to decrypt existing sessions after the rotation, while the GnuTLS provider only supports one key.
+
+ :param str ticketsKeysFile: The path to a file from where TLS tickets keys should be loaded.
--- /dev/null
+AC_DEFUN([DNSDIST_CHECK_GNUTLS], [
+ AC_MSG_CHECKING([whether we will be linking in GnuTLS])
+ AC_ARG_ENABLE([gnutls],
+ AS_HELP_STRING([--enable-gnutls],[use GnuTLS @<:@default=auto@:>@]),
+ [enable_gnutls=$enableval],
+ [enable_gnutls=no],
+ )
+ AC_MSG_RESULT([$enable_gnutls])
+
+ AS_IF([test "x$enable_gnutls" != "xno"], [
+ AS_IF([test "x$enable_gnutls" = "xyes" -o "x$enable_gnutls" = "xauto"], [
+ # we require gnutls_certificate_set_x509_key_file, added in 3.1.11
+ PKG_CHECK_MODULES([GNUTLS], [gnutls >= 3.1.11], [
+ AC_DEFINE([HAVE_GNUTLS], [1], [Define to 1 if you have GnuTLS])
+ ], [ : ])
+ ])
+ ])
+ AM_CONDITIONAL([HAVE_GNUTLS], [test "x$GNUTLS_LIBS" != "x"])
+ AS_IF([test "x$enable_gnutls" = "xyes"], [
+ AS_IF([test x"$GNUTLS_LIBS" = "x"], [
+ AC_MSG_ERROR([GnuTLS requested but libraries were not found])
+ ])
+ ])
+])
--- /dev/null
+AC_DEFUN([DNSDIST_CHECK_LIBSSL], [
+ HAVE_LIBSSL=0
+ AC_MSG_CHECKING([if OpenSSL libssl is available])
+ PKG_CHECK_MODULES([LIBSSL], [libssl], [
+ [HAVE_LIBSSL=1],
+ AC_DEFINE([HAVE_LIBSSL], [1], [Define to 1 if you have OpenSSL libssl])
+ ])
+ AM_CONDITIONAL([HAVE_LIBSSL], [test "x$LIBSSL_LIBS" != "x"])
+])
--- /dev/null
+AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_TLS], [
+ AC_MSG_CHECKING([whether to enable DNS over TLS support])
+ AC_ARG_ENABLE([dns-over-tls],
+ AS_HELP_STRING([--enable-dns-over-tls], [enable DNS over TLS support (require OpenSSL or s2n) @<:@default=no@:>@]),
+ [enable_dns_over_tls=$enableval],
+ [enable_dns_over_tls=no]
+ )
+ AC_MSG_RESULT([$enable_dns_over_tls])
+ AM_CONDITIONAL([HAVE_DNS_OVER_TLS], [test "x$enable_dns_over_tls" != "xno"])
+
+ AM_COND_IF([HAVE_DNS_OVER_TLS], [
+ AC_DEFINE([HAVE_DNS_OVER_TLS], [1], [Define to 1 if you enable DNS over TLS support])
+ ])
+])
--- /dev/null
+#include <fstream>
+
+#include "config.h"
+#include "dolog.hh"
+#include "iputils.hh"
+#include "lock.hh"
+#include "tcpiohandler.hh"
+
+#ifdef HAVE_LIBSODIUM
+#include <sodium.h>
+#endif /* HAVE_LIBSODIUM */
+
+#ifdef HAVE_DNS_OVER_TLS
+#ifdef HAVE_LIBSSL
+#include <openssl/conf.h>
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+
+#include <boost/circular_buffer.hpp>
+
+#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER)
+/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
+static pthread_mutex_t *openssllocks{nullptr};
+
+extern "C" {
+static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line)
+{
+ if (mode & CRYPTO_LOCK) {
+ pthread_mutex_lock(&(openssllocks[type]));
+
+ } else {
+ pthread_mutex_unlock(&(openssllocks[type]));
+ }
+}
+
+static unsigned long openssl_pthreads_id_callback()
+{
+ return (unsigned long)pthread_self();
+}
+}
+
+static void openssl_thread_setup()
+{
+ openssllocks = (pthread_mutex_t*)OPENSSL_malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
+
+ for (int i = 0; i < CRYPTO_num_locks(); i++)
+ pthread_mutex_init(&(openssllocks[i]), NULL);
+
+ CRYPTO_set_id_callback(openssl_pthreads_id_callback);
+ CRYPTO_set_locking_callback(openssl_pthreads_locking_callback);
+}
+
+static void openssl_thread_cleanup()
+{
+ CRYPTO_set_locking_callback(NULL);
+
+ for (int i=0; i<CRYPTO_num_locks(); i++) {
+ pthread_mutex_destroy(&(openssllocks[i]));
+ }
+
+ OPENSSL_free(openssllocks);
+}
+
+#else
+static void openssl_thread_setup()
+{
+}
+
+static void openssl_thread_cleanup()
+{
+}
+#endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER) */
+
+/* From rfc5077 Section 4. Recommended Ticket Construction */
+#define TLS_TICKETS_KEY_NAME_SIZE (16)
+
+/* AES-256 */
+#define TLS_TICKETS_CIPHER_KEY_SIZE (32)
+#define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)
+
+/* HMAC SHA-256 */
+#define TLS_TICKETS_MAC_KEY_SIZE (32)
+#define TLS_TICKETS_MAC_ALGO (EVP_sha256)
+
+static int s_ticketsKeyIndex{-1};
+
+class OpenSSLTLSTicketKey
+{
+public:
+ OpenSSLTLSTicketKey()
+ {
+ if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
+ throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
+ }
+
+ if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
+ throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
+ }
+
+ if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
+ throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
+ }
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_name, sizeof(d_name));
+ sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+ }
+
+ OpenSSLTLSTicketKey(ifstream& file)
+ {
+ file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
+ file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
+ file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
+
+ if (file.fail()) {
+ throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
+ }
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_name, sizeof(d_name));
+ sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+ }
+
+ ~OpenSSLTLSTicketKey()
+ {
+#ifdef HAVE_LIBSODIUM
+ sodium_munlock(d_name, sizeof(d_name));
+ sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
+#else
+ OPENSSL_cleanse(d_name, sizeof(d_name));
+ OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
+ OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+ }
+
+ bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
+ {
+ return (memcmp(d_name, name, sizeof(d_name)) == 0);
+ }
+
+ int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
+ {
+ memcpy(keyName, d_name, sizeof(d_name));
+
+ if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
+ return -1;
+ }
+
+ if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
+ return -1;
+ }
+
+ if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
+ return -1;
+ }
+
+ return 1;
+ }
+
+ bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
+ {
+ if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
+ return false;
+ }
+
+ if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
+ return false;
+ }
+
+ return true;
+ }
+
+private:
+ unsigned char d_name[TLS_TICKETS_KEY_NAME_SIZE];
+ unsigned char d_cipherKey[TLS_TICKETS_CIPHER_KEY_SIZE];
+ unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
+};
+
+class OpenSSLTLSTicketKeysRing
+{
+public:
+ OpenSSLTLSTicketKeysRing(size_t capacity)
+ {
+ pthread_rwlock_init(&d_lock, nullptr);
+ d_ticketKeys.set_capacity(capacity);
+ }
+
+ ~OpenSSLTLSTicketKeysRing()
+ {
+ pthread_rwlock_destroy(&d_lock);
+ }
+
+ void addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
+ {
+ WriteLock wl(&d_lock);
+ d_ticketKeys.push_back(newKey);
+ }
+
+ std::shared_ptr<OpenSSLTLSTicketKey> getEncryptionKey()
+ {
+ ReadLock rl(&d_lock);
+ return d_ticketKeys.front();
+ }
+
+ std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
+ {
+ ReadLock rl(&d_lock);
+ for (auto& key : d_ticketKeys) {
+ if (key->nameMatches(name)) {
+ activeKey = (key == d_ticketKeys.front());
+ return key;
+ }
+ }
+ return nullptr;
+ }
+
+ size_t getKeysCount()
+ {
+ ReadLock rl(&d_lock);
+ return d_ticketKeys.size();
+ }
+
+private:
+ boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > d_ticketKeys;
+ pthread_rwlock_t d_lock;
+};
+
+class OpenSSLTLSConnection: public TLSConnection
+{
+public:
+ OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx)
+ {
+ d_socket = socket;
+ d_conn = SSL_new(tlsCtx);
+
+ if (!d_conn) {
+ vinfolog("Error creating TLS object");
+ if (g_verbose) {
+ ERR_print_errors_fp(stderr);
+ }
+ throw std::runtime_error("Error creating TLS object");
+ }
+
+ if (!SSL_set_fd(d_conn, d_socket)) {
+ throw std::runtime_error("Error assigning socket");
+ }
+
+ int res = 0;
+ do {
+ res = SSL_accept(d_conn);
+ if (res < 0) {
+ handleIORequest(res, timeout);
+ }
+ }
+ while (res < 0);
+
+ if (res != 1) {
+ throw std::runtime_error("Error accepting TLS connection");
+ }
+ }
+
+ virtual ~OpenSSLTLSConnection() override
+ {
+ if (d_conn) {
+ SSL_free(d_conn);
+ }
+ }
+
+ void handleIORequest(int res, unsigned int timeout)
+ {
+ int error = SSL_get_error(d_conn, res);
+ if (error == SSL_ERROR_WANT_READ) {
+ res = waitForData(d_socket, timeout);
+ if (res <= 0) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ }
+ else if (error == SSL_ERROR_WANT_WRITE) {
+ res = waitForRWData(d_socket, false, timeout, 0);
+ if (res <= 0) {
+ throw std::runtime_error("Error waiting to write to TLS connection");
+ }
+ }
+ else {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ }
+
+ size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
+ {
+ size_t got = 0;
+ time_t start = 0;
+ unsigned int remainingTime = totalTimeout;
+ if (totalTimeout) {
+ start = time(nullptr);
+ }
+
+ do {
+ int res = SSL_read(d_conn, (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ if (res == 0) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ else if (res < 0) {
+ handleIORequest(res, readTimeout);
+ }
+ else {
+ got += (size_t) res;
+ }
+
+ if (totalTimeout) {
+ time_t now = time(nullptr);
+ unsigned int elapsed = now - start;
+ if (now < start || elapsed >= remainingTime) {
+ throw runtime_error("Timeout while reading data");
+ }
+ start = now;
+ remainingTime -= elapsed;
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
+ {
+ size_t got = 0;
+ do {
+ int res = SSL_write(d_conn, (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ if (res == 0) {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ else if (res < 0) {
+ handleIORequest(res, writeTimeout);
+ }
+ else {
+ got += (size_t) res;
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+ void close() override
+ {
+ if (d_conn) {
+ SSL_shutdown(d_conn);
+ }
+ }
+
+private:
+ SSL* d_conn{nullptr};
+};
+
+class OpenSSLTLSIOCtx: public TLSCtx
+{
+public:
+ OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys)
+ {
+ d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
+
+ static const int sslOptions =
+ SSL_OP_NO_SSLv2 |
+ SSL_OP_NO_SSLv3 |
+ SSL_OP_NO_COMPRESSION |
+ SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
+ SSL_OP_SINGLE_DH_USE |
+ SSL_OP_SINGLE_ECDH_USE |
+ SSL_OP_CIPHER_SERVER_PREFERENCE;
+
+ if (s_users.fetch_add(1) == 0) {
+ ERR_load_crypto_strings();
+ OpenSSL_add_ssl_algorithms();
+ openssl_thread_setup();
+
+ s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+
+ if (s_ticketsKeyIndex == -1) {
+ throw std::runtime_error("Error getting an index for tickets key");
+ }
+ }
+
+ d_tlsCtx = SSL_CTX_new(SSLv23_server_method());
+ if (!d_tlsCtx) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
+ }
+
+ /* use the internal built-in cache to store sessions */
+ SSL_CTX_set_session_cache_mode(d_tlsCtx, SSL_SESS_CACHE_SERVER);
+ /* use our own ticket keys handler so we can rotate them */
+ SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx, &OpenSSLTLSIOCtx::ticketKeyCb);
+ SSL_CTX_set_ex_data(d_tlsCtx, s_ticketsKeyIndex, this);
+ SSL_CTX_set_options(d_tlsCtx, sslOptions);
+#if defined(SSL_CTX_set_ecdh_auto)
+ SSL_CTX_set_ecdh_auto(d_tlsCtx, 1);
+#endif
+ SSL_CTX_use_certificate_chain_file(d_tlsCtx, fe.d_certFile.c_str());
+ SSL_CTX_use_PrivateKey_file(d_tlsCtx, fe.d_keyFile.c_str(), SSL_FILETYPE_PEM);
+
+ if (!fe.d_ciphers.empty()) {
+ SSL_CTX_set_cipher_list(d_tlsCtx, fe.d_ciphers.c_str());
+ }
+
+ try {
+ if (fe.d_ticketKeyFile.empty()) {
+ handleTicketsKeyRotation(time(nullptr));
+ }
+ else {
+ loadTicketsKeys(fe.d_ticketKeyFile);
+ }
+ }
+ catch (const std::exception& e) {
+ SSL_CTX_free(d_tlsCtx);
+ throw;
+ }
+ }
+
+ virtual ~OpenSSLTLSIOCtx() override
+ {
+ if (d_tlsCtx) {
+ SSL_CTX_free(d_tlsCtx);
+ }
+
+ if (s_users.fetch_sub(1) == 1) {
+ ERR_free_strings();
+
+ EVP_cleanup();
+
+ CONF_modules_finish();
+ CONF_modules_free();
+ CONF_modules_unload(1);
+
+ CRYPTO_cleanup_all_ex_data();
+ openssl_thread_cleanup();
+ }
+ }
+
+ static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
+ {
+ SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
+ if (sslCtx == nullptr) {
+ return -1;
+ }
+
+ OpenSSLTLSIOCtx* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex));
+ if (ctx == nullptr) {
+ return -1;
+ }
+
+ if (enc) {
+ const auto key = ctx->d_ticketKeys.getEncryptionKey();
+ if (key == nullptr) {
+ return -1;
+ }
+
+ return key->encrypt(keyName, iv, ectx, hctx);
+ }
+
+ bool activeEncryptionKey = false;
+
+ const auto key = ctx->d_ticketKeys.getDecryptionKey(keyName, activeEncryptionKey);
+ if (key == nullptr) {
+ /* we don't know this key, just create a new ticket */
+ return 0;
+ }
+
+ if (key->decrypt(iv, ectx, hctx) == false) {
+ return -1;
+ }
+
+ if (!activeEncryptionKey) {
+ /* this key is not active, please encrypt the ticket content with the currently active one */
+ return 2;
+ }
+
+ return 1;
+ }
+
+ std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
+ {
+ handleTicketsKeyRotation(now);
+
+ return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx));
+ }
+
+ void rotateTicketsKey(time_t now) override
+ {
+ auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
+ d_ticketKeys.addKey(newKey);
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ void loadTicketsKeys(const std::string& keyFile) override
+ {
+ bool keyLoaded = false;
+ ifstream file(keyFile);
+ try {
+ do {
+ auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
+ d_ticketKeys.addKey(newKey);
+ keyLoaded = true;
+ }
+ while (!file.fail());
+ }
+ catch (const std::exception& e) {
+ /* if we haven't been able to load at least one key, fail */
+ if (!keyLoaded) {
+ throw;
+ }
+ }
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+
+ file.close();
+ }
+
+ size_t getTicketsKeysCount() override
+ {
+ return d_ticketKeys.getKeysCount();
+ }
+
+private:
+ OpenSSLTLSTicketKeysRing d_ticketKeys;
+ SSL_CTX* d_tlsCtx{nullptr};
+ static std::atomic<uint64_t> s_users;
+};
+
+std::atomic<uint64_t> OpenSSLTLSIOCtx::s_users(0);
+
+#endif /* HAVE_LIBSSL */
+
+#ifdef HAVE_GNUTLS
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+
+class GnuTLSTicketsKey
+{
+public:
+ GnuTLSTicketsKey()
+ {
+ if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error generating tickets key for TLS context");
+ }
+
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_key.data, d_key.size);
+#endif /* HAVE_LIBSODIUM */
+ }
+
+ GnuTLSTicketsKey(const std::string& keyFile)
+ {
+ /* to be sure we are loading the correct amount of data, which
+ may change between versions, let's generate a correct key first */
+ if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
+ }
+
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_key.data, d_key.size);
+#endif /* HAVE_LIBSODIUM */
+
+ try {
+ ifstream file(keyFile);
+ file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
+
+ if (file.fail()) {
+ file.close();
+ throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
+ }
+
+ file.close();
+ }
+ catch (const std::exception& e) {
+#ifdef HAVE_LIBSODIUM
+ sodium_munlock(d_key.data, d_key.size);
+#endif /* HAVE_LIBSODIUM */
+ gnutls_free(d_key.data);
+ throw;
+ }
+ }
+
+ ~GnuTLSTicketsKey()
+ {
+ if (d_key.data != nullptr && d_key.size > 0) {
+#ifdef HAVE_LIBSODIUM
+ sodium_munlock(d_key.data, d_key.size);
+#else
+ gnutls_memset(d_key.data, 0, d_key.size);
+#endif /* HAVE_LIBSODIUM */
+ }
+ gnutls_free(d_key.data);
+ }
+ const gnutls_datum_t& getKey() const
+ {
+ return d_key;
+ }
+
+private:
+ gnutls_datum_t d_key{nullptr, 0};
+};
+
+class GnuTLSConnection: public TLSConnection
+{
+public:
+
+ GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey> ticketsKey): d_ticketsKey(ticketsKey)
+ {
+ d_socket = socket;
+
+ if (gnutls_init(&d_conn, GNUTLS_SERVER
+#ifdef GNUTLS_NO_SIGNAL
+ | GNUTLS_NO_SIGNAL
+#endif
+ ) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error creating TLS connection");
+ }
+
+ if (gnutls_credentials_set(d_conn, GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
+ gnutls_deinit(d_conn);
+ throw std::runtime_error("Error setting certificate and key to TLS connection");
+ }
+
+ if (gnutls_priority_set(d_conn, priorityCache) != GNUTLS_E_SUCCESS) {
+ gnutls_deinit(d_conn);
+ throw std::runtime_error("Error setting ciphers to TLS connection");
+ }
+
+ if (d_ticketsKey) {
+ const gnutls_datum_t& key = d_ticketsKey->getKey();
+ if (gnutls_session_ticket_enable_server(d_conn, &key) != GNUTLS_E_SUCCESS) {
+ gnutls_deinit(d_conn);
+ throw std::runtime_error("Error setting the tickets key to TLS connection");
+ }
+ }
+
+ gnutls_transport_set_int(d_conn, d_socket);
+
+ /* timeouts are in milliseconds */
+ gnutls_handshake_set_timeout(d_conn, timeout * 1000);
+ gnutls_record_set_timeout(d_conn, timeout * 1000);
+
+ int ret = 0;
+ do {
+ ret = gnutls_handshake(d_conn);
+ }
+ while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
+ }
+
+ virtual ~GnuTLSConnection() override
+ {
+ if (d_conn) {
+ gnutls_deinit(d_conn);
+ }
+ }
+
+ size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
+ {
+ size_t got = 0;
+ time_t start = 0;
+ unsigned int remainingTime = totalTimeout;
+ if (totalTimeout) {
+ start = time(nullptr);
+ }
+
+ do {
+ ssize_t res = gnutls_record_recv(d_conn, (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
+ if (res == 0) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ else if (res > 0) {
+ got += (size_t) res;
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ warnlog("Warning, non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
+ }
+
+ if (totalTimeout) {
+ time_t now = time(nullptr);
+ unsigned int elapsed = now - start;
+ if (now < start || elapsed >= remainingTime) {
+ throw runtime_error("Timeout while reading data");
+ }
+ start = now;
+ remainingTime -= elapsed;
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
+ {
+ size_t got = 0;
+
+ do {
+ ssize_t res = gnutls_record_send(d_conn, (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
+ if (res == 0) {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ else if (res > 0) {
+ got += (size_t) res;
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ void close() override
+ {
+ if (d_conn) {
+ gnutls_bye(d_conn, GNUTLS_SHUT_WR);
+ }
+ }
+
+private:
+ gnutls_session_t d_conn{nullptr};
+ std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
+};
+
+class GnuTLSIOCtx: public TLSCtx
+{
+public:
+ GnuTLSIOCtx(const TLSFrontend& fe)
+ {
+ d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
+
+ if (gnutls_certificate_allocate_credentials(&d_creds) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort());
+ }
+
+ if (gnutls_certificate_set_x509_key_file(d_creds, fe.d_certFile.c_str(), fe.d_keyFile.c_str(), GNUTLS_X509_FMT_PEM) != GNUTLS_E_SUCCESS) {
+ gnutls_certificate_free_credentials(d_creds);
+ throw std::runtime_error("Error loading certificate and key for TLS context on " + fe.d_addr.toStringWithPort());
+ }
+
+#if GNUTLS_VERSION_NUMBER >= 0x030600
+ if (gnutls_certificate_set_known_dh_params(d_creds, GNUTLS_SEC_PARAM_HIGH) != GNUTLS_E_SUCCESS) {
+ gnutls_certificate_free_credentials(d_creds);
+ throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort());
+ }
+#endif
+
+ if (gnutls_priority_init(&d_priorityCache, fe.d_ciphers.empty() ? "NORMAL" : fe.d_ciphers.c_str(), nullptr) != GNUTLS_E_SUCCESS) {
+ warnlog("Error setting up TLS cipher preferences to %s, skipping.", fe.d_ciphers.c_str());
+ }
+
+ try {
+ if (fe.d_ticketKeyFile.empty()) {
+ handleTicketsKeyRotation(time(nullptr));
+ }
+ else {
+ loadTicketsKeys(fe.d_ticketKeyFile);
+ }
+ }
+ catch(const std::runtime_error& e) {
+ gnutls_certificate_free_credentials(d_creds);
+ throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
+ }
+ }
+
+ virtual ~GnuTLSIOCtx() override
+ {
+ if (d_creds) {
+ gnutls_certificate_free_credentials(d_creds);
+ }
+ if (d_priorityCache) {
+ gnutls_priority_deinit(d_priorityCache);
+ }
+ }
+
+ std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
+ {
+ handleTicketsKeyRotation(now);
+
+ return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds, d_priorityCache, d_ticketsKey));
+ }
+
+ void rotateTicketsKey(time_t now) override
+ {
+ auto newKey = std::make_shared<GnuTLSTicketsKey>();
+ d_ticketsKey = newKey;
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ void loadTicketsKeys(const std::string& file) override
+ {
+ auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
+ d_ticketsKey = newKey;
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ size_t getTicketsKeysCount() override
+ {
+ return d_ticketsKey != nullptr ? 1 : 0;
+ }
+
+private:
+ gnutls_certificate_credentials_t d_creds{nullptr};
+ gnutls_priority_t d_priorityCache{nullptr};
+ std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
+};
+
+#endif /* HAVE_GNUTLS */
+
+#endif /* HAVE_DNS_OVER_TLS */
+
+bool TLSFrontend::setupTLS()
+{
+#ifdef HAVE_DNS_OVER_TLS
+ /* get the "best" available provider */
+ if (!d_provider.empty()) {
+#ifdef HAVE_GNUTLS
+ if (d_provider == "gnutls") {
+ d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
+ return true;
+ }
+#endif /* HAVE_GNUTLS */
+#ifdef HAVE_LIBSSL
+ if (d_provider == "openssl") {
+ d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+ return true;
+ }
+#endif /* HAVE_LIBSSL */
+ }
+#ifdef HAVE_GNUTLS
+ d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
+#else /* HAVE_GNUTLS */
+#ifdef HAVE_LIBSSL
+ d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+#endif /* HAVE_LIBSSL */
+#endif /* HAVE_GNUTLS */
+
+#endif /* HAVE_DNS_OVER_TLS */
+ return true;
+}
--- /dev/null
+
+#pragma once
+#include <memory>
+
+#include "misc.hh"
+
+class TLSConnection
+{
+public:
+ virtual ~TLSConnection() { }
+ virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
+ virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
+ virtual void close() = 0;
+
+protected:
+ int d_socket{-1};
+};
+
+class TLSCtx
+{
+public:
+ virtual ~TLSCtx() {}
+ virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0;
+ virtual void rotateTicketsKey(time_t now) = 0;
+ virtual void loadTicketsKeys(const std::string& file)
+ {
+ throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
+ }
+
+ void handleTicketsKeyRotation(time_t now)
+ {
+ if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
+ if (d_rotatingTicketsKey.test_and_set()) {
+ /* someone is already rotating */
+ return;
+ }
+ try {
+ rotateTicketsKey(now);
+ d_rotatingTicketsKey.clear();
+ }
+ catch(const std::runtime_error& e) {
+ d_rotatingTicketsKey.clear();
+ throw std::runtime_error("Error generating a new tickets key for TLS context");
+ }
+ }
+ }
+
+ time_t getNextTicketsKeyRotation() const
+ {
+ return d_ticketsKeyNextRotation;
+ }
+
+ virtual size_t getTicketsKeysCount() = 0;
+
+protected:
+ std::atomic_flag d_rotatingTicketsKey{ATOMIC_FLAG_INIT};
+ time_t d_ticketsKeyRotationDelay{0};
+ time_t d_ticketsKeyNextRotation{0};
+};
+
+class TLSFrontend
+{
+public:
+ bool setupTLS();
+
+ void rotateTicketsKey(time_t now)
+ {
+ if (d_ctx != nullptr) {
+ d_ctx->rotateTicketsKey(now);
+ }
+ }
+
+ void loadTicketsKeys(const std::string& file)
+ {
+ if (d_ctx != nullptr) {
+ d_ctx->loadTicketsKeys(file);
+ }
+ }
+
+ std::shared_ptr<TLSCtx> getContext()
+ {
+ return d_ctx;
+ }
+
+ void cleanup()
+ {
+ d_ctx.reset();
+ }
+
+ size_t getTicketsKeysCount()
+ {
+ if (d_ctx != nullptr) {
+ return d_ctx->getTicketsKeysCount();
+ }
+
+ return 0;
+ }
+
+ static std::string timeToString(time_t rotationTime)
+ {
+ char buf[20];
+ struct tm date_tm;
+
+ localtime_r(&rotationTime, &date_tm);
+ strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
+
+ return std::string(buf);
+ }
+
+ time_t getTicketsKeyRotationDelay() const
+ {
+ return d_ticketsKeyRotationDelay;
+ }
+
+ std::string getNextTicketsKeyRotation() const
+ {
+ std::string res;
+
+ if (d_ctx != nullptr) {
+ res = timeToString(d_ctx->getNextTicketsKeyRotation());
+ }
+
+ return res;
+ }
+
+ std::set<int> d_cpus;
+ ComboAddress d_addr;
+ std::string d_certFile;
+ std::string d_keyFile;
+ std::string d_ciphers;
+ std::string d_provider;
+ std::string d_interface;
+ std::string d_ticketKeyFile;
+
+ time_t d_ticketsKeyRotationDelay{43200};
+ int d_tcpFastOpenQueueSize{0};
+ uint8_t d_numberOfTicketsKeys{5};
+ bool d_reusePort{false};
+
+private:
+ std::shared_ptr<TLSCtx> d_ctx{nullptr};
+};
+
+class TCPIOHandler
+{
+public:
+ TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
+ {
+ if (ctx) {
+ d_conn = ctx->getConnection(d_socket, timeout, now);
+ }
+ }
+ ~TCPIOHandler()
+ {
+ if (d_conn) {
+ d_conn->close();
+ }
+ else if (d_socket != -1) {
+ shutdown(d_socket, SHUT_RDWR);
+ }
+ }
+ size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
+ {
+ if (d_conn) {
+ return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout);
+ } else {
+ return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
+ }
+ }
+ size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
+ {
+ if (d_conn) {
+ return d_conn->write(buffer, bufferSize, writeTimeout);
+ }
+ else {
+ return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
+ }
+ }
+
+ bool writeSizeAndMsg(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
+ {
+ if (d_conn) {
+ uint16_t size = htons(bufferSize);
+ if (d_conn->write(&size, sizeof(size), writeTimeout) != sizeof(size)) {
+ return false;
+ }
+ return (d_conn->write(buffer, bufferSize, writeTimeout) == bufferSize);
+ }
+ else {
+ return sendSizeAndMsgWithTimeout(d_socket, bufferSize, static_cast<const char*>(buffer), writeTimeout, nullptr, nullptr, 0, 0, 0);
+ }
+ }
+
+private:
+ std::unique_ptr<TLSConnection> d_conn{nullptr};
+ int d_socket{-1};
+};
--- /dev/null
+[req]
+default_bits = 2048
+encrypt_key = no
+x509_extensions = custom_extensions
+prompt = no
+distinguished_name = distinguished_name
+
+[v3_ca]
+subjectKeyIdentifier = hash
+authorityKeyIdentifier = keyid:always,issuer:always
+basicConstraints = critical, CA:true
+
+[distinguished_name]
+CN = DNSDist TLS regression tests CA
+OU = PowerDNS.com BV
+countryName = NL
+
+[custom_extensions]
+basicConstraints = CA:true
+keyUsage = cRLSign, keyCertSign
--- /dev/null
+[req]
+default_bits = 2048
+encrypt_key = no
+prompt = no
+distinguished_name = server_distinguished_name
+
+[server_distinguished_name]
+CN = tls.tests.dnsdist.org
+OU = PowerDNS.com BV
+countryName = NL
+
import Queue
import os
import socket
+import ssl
import struct
import subprocess
import sys
return sock
@classmethod
- def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False):
+ def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if timeout:
+ sock.settimeout(timeout)
+
+ # 2.7.9+
+ if hasattr(ssl, 'create_default_context'):
+ sslctx = ssl.create_default_context(cafile=caCert)
+ sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
+ else:
+ sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
+
+ sslsock.connect(("127.0.0.1", port))
+ return sslsock
+
+ @classmethod
+ def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
if not rawQuery:
wire = query.to_wire()
else:
wire = query
+ if response:
+ cls._toResponderQueue.put(response, True, timeout)
+
sock.send(struct.pack("!H", len(wire)))
sock.send(wire)
@classmethod
- def recvTCPResponseOverConnection(cls, sock):
+ def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
message = None
data = sock.recv(2)
if data:
data = sock.recv(datalen)
if data:
message = dns.message.from_wire(data)
- return message
+
+ if useQueue and not cls._fromResponderQueue.empty():
+ receivedQuery = cls._fromResponderQueue.get(True, timeout)
+ return (receivedQuery, message)
+ else:
+ return message
@classmethod
def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
if [ "${PDNS_DEBUG}" = "YES" ]; then
set -x
fi
+
+# Generate a new CA
+openssl req -new -x509 -days 1 -extensions v3_ca -keyout ca.key -out ca.pem -nodes -config configCA.conf
+# Generate a new server certificate request
+openssl req -new -newkey rsa:2048 -nodes -keyout server.key -out server.csr -config configServer.conf
+# Sign the server cert
+openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server.csr -out server.pem
+# Generate a chain
+cat server.pem ca.pem >> server.chain
+
nosetests --with-xunit $@
+
+rm ca.key ca.pem ca.srl server.csr server.key server.pem server.chain
--- /dev/null
+#!/usr/bin/env python
+import dns
+from dnsdisttests import DNSDistTest
+
+class TestTLS(DNSDistTest):
+
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _tlsServerPort = 8453
+ _config_template = """
+ newServer{address="127.0.0.1:%s"}
+ addTLSLocal("127.0.0.1:%s", "%s", "%s")
+ """
+ _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
+
+ def testTLSSimple(self):
+ """
+ TLS: Single query
+ """
+ name = 'single.tls.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+ self.sendTCPQueryOverConnection(conn, query, response=response)
+ (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)