From: Remi Gacogne Date: Fri, 12 Oct 2018 16:14:15 +0000 (+0200) Subject: dnsdist: Wrap GnuTLS and OpenSSL pointers in smart pointers X-Git-Tag: dnsdist-1.3.3~42^2 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=8dd7033badcb3842b721571153297af95dd88f7b;p=pdns dnsdist: Wrap GnuTLS and OpenSSL pointers in smart pointers --- diff --git a/pdns/dnsdistdist/tcpiohandler.cc b/pdns/dnsdistdist/tcpiohandler.cc index eb05d81e6..234c19fc7 100644 --- a/pdns/dnsdistdist/tcpiohandler.cc +++ b/pdns/dnsdistdist/tcpiohandler.cc @@ -232,10 +232,9 @@ private: class OpenSSLTLSConnection: public TLSConnection { public: - OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx) + OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr(SSL_new(tlsCtx), SSL_free)) { d_socket = socket; - d_conn = SSL_new(tlsCtx); if (!d_conn) { vinfolog("Error creating TLS object"); @@ -245,45 +244,27 @@ public: throw std::runtime_error("Error creating TLS object"); } - if (!SSL_set_fd(d_conn, d_socket)) { - SSL_free(d_conn); - d_conn = nullptr; + if (!SSL_set_fd(d_conn.get(), d_socket)) { throw std::runtime_error("Error assigning socket"); } int res = 0; do { - res = SSL_accept(d_conn); + res = SSL_accept(d_conn.get()); if (res < 0) { - try { - handleIORequest(res, timeout); - } - catch(...) { - SSL_free(d_conn); - d_conn = nullptr; - throw; - } + handleIORequest(res, timeout); } } while (res < 0); if (res != 1) { - SSL_free(d_conn); - d_conn = nullptr; throw std::runtime_error("Error accepting TLS connection"); } } - virtual ~OpenSSLTLSConnection() override - { - if (d_conn) { - SSL_free(d_conn); - } - } - void handleIORequest(int res, unsigned int timeout) { - int error = SSL_get_error(d_conn, res); + int error = SSL_get_error(d_conn.get(), res); if (error == SSL_ERROR_WANT_READ) { res = waitForData(d_socket, timeout); if (res <= 0) { @@ -311,7 +292,7 @@ public: } do { - int res = SSL_read(d_conn, (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); + int res = SSL_read(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); if (res == 0) { throw std::runtime_error("Error reading from TLS connection"); } @@ -341,7 +322,7 @@ public: { size_t got = 0; do { - int res = SSL_write(d_conn, (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); + int res = SSL_write(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); if (res == 0) { throw std::runtime_error("Error writing to TLS connection"); } @@ -359,18 +340,18 @@ public: void close() override { if (d_conn) { - SSL_shutdown(d_conn); + SSL_shutdown(d_conn.get()); } } private: - SSL* d_conn{nullptr}; + std::unique_ptr d_conn; }; class OpenSSLTLSIOCtx: public TLSCtx { public: - OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys) + OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr(nullptr, SSL_CTX_free)) { d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay; @@ -399,42 +380,36 @@ public: } } - d_tlsCtx = SSL_CTX_new(SSLv23_server_method()); + d_tlsCtx = std::unique_ptr(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free); if (!d_tlsCtx) { ERR_print_errors_fp(stderr); throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort()); } /* use the internal built-in cache to store sessions */ - SSL_CTX_set_session_cache_mode(d_tlsCtx, SSL_SESS_CACHE_SERVER); + SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER); /* use our own ticket keys handler so we can rotate them */ - SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx, &OpenSSLTLSIOCtx::ticketKeyCb); - SSL_CTX_set_ex_data(d_tlsCtx, s_ticketsKeyIndex, this); - SSL_CTX_set_options(d_tlsCtx, sslOptions); + SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); + SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this); + SSL_CTX_set_options(d_tlsCtx.get(), sslOptions); #if defined(SSL_CTX_set_ecdh_auto) - SSL_CTX_set_ecdh_auto(d_tlsCtx, 1); + SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1); #endif for (const auto& pair : fe.d_certKeyPairs) { - if (SSL_CTX_use_certificate_chain_file(d_tlsCtx, pair.first.c_str()) != 1) { + if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) { ERR_print_errors_fp(stderr); - SSL_CTX_free(d_tlsCtx); - d_tlsCtx = nullptr; throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort()); } - if (SSL_CTX_use_PrivateKey_file(d_tlsCtx, pair.second.c_str(), SSL_FILETYPE_PEM) != 1) { + if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) { ERR_print_errors_fp(stderr); - SSL_CTX_free(d_tlsCtx); - d_tlsCtx = nullptr; throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort()); } } if (!fe.d_ciphers.empty()) { - if (SSL_CTX_set_cipher_list(d_tlsCtx, fe.d_ciphers.c_str()) != 1) { + if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) { ERR_print_errors_fp(stderr); - SSL_CTX_free(d_tlsCtx); - d_tlsCtx = nullptr; throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort()); } } @@ -448,17 +423,13 @@ public: } } catch (const std::exception& e) { - SSL_CTX_free(d_tlsCtx); - d_tlsCtx = nullptr; throw; } } virtual ~OpenSSLTLSIOCtx() override { - if (d_tlsCtx) { - SSL_CTX_free(d_tlsCtx); - } + d_tlsCtx.reset(); if (s_users.fetch_sub(1) == 1) { ERR_free_strings(); @@ -519,7 +490,7 @@ public: { handleTicketsKeyRotation(now); - return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx)); + return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get())); } void rotateTicketsKey(time_t now) override @@ -565,7 +536,7 @@ public: private: OpenSSLTLSTicketKeysRing d_ticketKeys; - SSL_CTX* d_tlsCtx{nullptr}; + std::unique_ptr d_tlsCtx; static std::atomic s_users; }; @@ -668,7 +639,7 @@ class GnuTLSConnection: public TLSConnection { public: - GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey) + GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey) { unsigned int sslOptions = GNUTLS_SERVER; #ifdef GNUTLS_NO_SIGNAL @@ -677,48 +648,42 @@ public: d_socket = socket; - if (gnutls_init(&d_conn, sslOptions) != GNUTLS_E_SUCCESS) { + gnutls_session_t conn; + if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error creating TLS connection"); } - if (gnutls_credentials_set(d_conn, GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) { - gnutls_deinit(d_conn); + d_conn = std::unique_ptr(conn, gnutls_deinit); + conn = nullptr; + + if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting certificate and key to TLS connection"); } - if (gnutls_priority_set(d_conn, priorityCache) != GNUTLS_E_SUCCESS) { - gnutls_deinit(d_conn); + if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting ciphers to TLS connection"); } if (enableTickets && d_ticketsKey) { const gnutls_datum_t& key = d_ticketsKey->getKey(); - if (gnutls_session_ticket_enable_server(d_conn, &key) != GNUTLS_E_SUCCESS) { - gnutls_deinit(d_conn); + if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting the tickets key to TLS connection"); } } - gnutls_transport_set_int(d_conn, d_socket); + gnutls_transport_set_int(d_conn.get(), d_socket); /* timeouts are in milliseconds */ - gnutls_handshake_set_timeout(d_conn, timeout * 1000); - gnutls_record_set_timeout(d_conn, timeout * 1000); + gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000); + gnutls_record_set_timeout(d_conn.get(), timeout * 1000); int ret = 0; do { - ret = gnutls_handshake(d_conn); + ret = gnutls_handshake(d_conn.get()); } while (ret < 0 && gnutls_error_is_fatal(ret) == 0); } - virtual ~GnuTLSConnection() override - { - if (d_conn) { - gnutls_deinit(d_conn); - } - } - size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override { size_t got = 0; @@ -729,7 +694,7 @@ public: } do { - ssize_t res = gnutls_record_recv(d_conn, (reinterpret_cast(buffer) + got), bufferSize - got); + ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); if (res == 0) { throw std::runtime_error("Error reading from TLS connection"); } @@ -763,7 +728,7 @@ public: size_t got = 0; do { - ssize_t res = gnutls_record_send(d_conn, (reinterpret_cast(buffer) + got), bufferSize - got); + ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); if (res == 0) { throw std::runtime_error("Error writing to TLS connection"); } @@ -785,40 +750,42 @@ public: void close() override { if (d_conn) { - gnutls_bye(d_conn, GNUTLS_SHUT_WR); + gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR); } } private: - gnutls_session_t d_conn{nullptr}; + std::unique_ptr d_conn; std::shared_ptr d_ticketsKey; }; class GnuTLSIOCtx: public TLSCtx { public: - GnuTLSIOCtx(const TLSFrontend& fe): d_enableTickets(fe.d_enableTickets) + GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets) { int rc = 0; d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay; - rc = gnutls_certificate_allocate_credentials(&d_creds); + gnutls_certificate_credentials_t creds; + rc = gnutls_certificate_allocate_credentials(&creds); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } + d_creds = std::unique_ptr(creds, gnutls_certificate_free_credentials); + creds = nullptr; + for (const auto& pair : fe.d_certKeyPairs) { - rc = gnutls_certificate_set_x509_key_file(d_creds, pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); + rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); if (rc != GNUTLS_E_SUCCESS) { - gnutls_certificate_free_credentials(d_creds); throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } } #if GNUTLS_VERSION_NUMBER >= 0x030600 - rc = gnutls_certificate_set_known_dh_params(d_creds, GNUTLS_SEC_PARAM_HIGH); + rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH); if (rc != GNUTLS_E_SUCCESS) { - gnutls_certificate_free_credentials(d_creds); throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } #endif @@ -837,16 +804,14 @@ public: } } catch(const std::runtime_error& e) { - gnutls_certificate_free_credentials(d_creds); throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what()); } } virtual ~GnuTLSIOCtx() override { - if (d_creds) { - gnutls_certificate_free_credentials(d_creds); - } + d_creds.reset(); + if (d_priorityCache) { gnutls_priority_deinit(d_priorityCache); } @@ -856,7 +821,7 @@ public: { handleTicketsKeyRotation(now); - return std::unique_ptr(new GnuTLSConnection(socket, timeout, d_creds, d_priorityCache, d_ticketsKey, d_enableTickets)); + return std::unique_ptr(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, d_ticketsKey, d_enableTickets)); } void rotateTicketsKey(time_t now) override @@ -891,7 +856,7 @@ public: } private: - gnutls_certificate_credentials_t d_creds{nullptr}; + std::unique_ptr d_creds; gnutls_priority_t d_priorityCache{nullptr}; std::shared_ptr d_ticketsKey{nullptr}; bool d_enableTickets{true};