]> granicus.if.org Git - pdns/commitdiff
dnsdist: Wrap GnuTLS and OpenSSL pointers in smart pointers
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 12 Oct 2018 16:14:15 +0000 (18:14 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 15 Oct 2018 13:39:08 +0000 (15:39 +0200)
pdns/dnsdistdist/tcpiohandler.cc

index eb05d81e610b37a72d00d88a9db7fa8991438559..234c19fc74928e9bcf878c6cf8d09b0ad95cca08 100644 (file)
@@ -232,10 +232,9 @@ private:
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
-  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx)
+  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free))
   {
     d_socket = socket;
-    d_conn = SSL_new(tlsCtx);
 
     if (!d_conn) {
       vinfolog("Error creating TLS object");
@@ -245,45 +244,27 @@ public:
       throw std::runtime_error("Error creating TLS object");
     }
 
-    if (!SSL_set_fd(d_conn, d_socket)) {
-      SSL_free(d_conn);
-      d_conn = nullptr;
+    if (!SSL_set_fd(d_conn.get(), d_socket)) {
       throw std::runtime_error("Error assigning socket");
     }
 
     int res = 0;
     do {
-      res = SSL_accept(d_conn);
+      res = SSL_accept(d_conn.get());
       if (res < 0) {
-        try {
-          handleIORequest(res, timeout);
-        }
-        catch(...) {
-          SSL_free(d_conn);
-          d_conn = nullptr;
-          throw;
-        }
+        handleIORequest(res, timeout);
       }
     }
     while (res < 0);
 
     if (res != 1) {
-      SSL_free(d_conn);
-      d_conn = nullptr;
       throw std::runtime_error("Error accepting TLS connection");
     }
   }
 
-  virtual ~OpenSSLTLSConnection() override
-  {
-    if (d_conn) {
-      SSL_free(d_conn);
-    }
-  }
-
   void handleIORequest(int res, unsigned int timeout)
   {
-    int error = SSL_get_error(d_conn, res);
+    int error = SSL_get_error(d_conn.get(), res);
     if (error == SSL_ERROR_WANT_READ) {
       res = waitForData(d_socket, timeout);
       if (res <= 0) {
@@ -311,7 +292,7 @@ public:
     }
 
     do {
-      int res = SSL_read(d_conn, (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
+      int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
       if (res == 0) {
         throw std::runtime_error("Error reading from TLS connection");
       }
@@ -341,7 +322,7 @@ public:
   {
     size_t got = 0;
     do {
-      int res = SSL_write(d_conn, (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
+      int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
       if (res == 0) {
         throw std::runtime_error("Error writing to TLS connection");
       }
@@ -359,18 +340,18 @@ public:
   void close() override
   {
     if (d_conn) {
-      SSL_shutdown(d_conn);
+      SSL_shutdown(d_conn.get());
     }
   }
 
 private:
-  SSL* d_conn{nullptr};
+  std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
 };
 
 class OpenSSLTLSIOCtx: public TLSCtx
 {
 public:
-  OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys)
+  OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
   {
     d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
 
@@ -399,42 +380,36 @@ public:
       }
     }
 
-    d_tlsCtx = SSL_CTX_new(SSLv23_server_method());
+    d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
     if (!d_tlsCtx) {
       ERR_print_errors_fp(stderr);
       throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
     }
 
     /* use the internal built-in cache to store sessions */
-    SSL_CTX_set_session_cache_mode(d_tlsCtx, SSL_SESS_CACHE_SERVER);
+    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER);
     /* use our own ticket keys handler so we can rotate them */
-    SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx, &OpenSSLTLSIOCtx::ticketKeyCb);
-    SSL_CTX_set_ex_data(d_tlsCtx, s_ticketsKeyIndex, this);
-    SSL_CTX_set_options(d_tlsCtx, sslOptions);
+    SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+    SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this);
+    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
 #if defined(SSL_CTX_set_ecdh_auto)
-    SSL_CTX_set_ecdh_auto(d_tlsCtx, 1);
+    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
 #endif
 
     for (const auto& pair : fe.d_certKeyPairs) {
-      if (SSL_CTX_use_certificate_chain_file(d_tlsCtx, pair.first.c_str()) != 1) {
+      if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) {
         ERR_print_errors_fp(stderr);
-        SSL_CTX_free(d_tlsCtx);
-        d_tlsCtx = nullptr;
         throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort());
       }
-      if (SSL_CTX_use_PrivateKey_file(d_tlsCtx, pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
+      if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
         ERR_print_errors_fp(stderr);
-        SSL_CTX_free(d_tlsCtx);
-        d_tlsCtx = nullptr;
         throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort());
       }
     }
 
     if (!fe.d_ciphers.empty()) {
-      if (SSL_CTX_set_cipher_list(d_tlsCtx, fe.d_ciphers.c_str()) != 1) {
+      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) {
         ERR_print_errors_fp(stderr);
-        SSL_CTX_free(d_tlsCtx);
-        d_tlsCtx = nullptr;
         throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort());
       }
     }
@@ -448,17 +423,13 @@ public:
       }
     }
     catch (const std::exception& e) {
-      SSL_CTX_free(d_tlsCtx);
-      d_tlsCtx = nullptr;
       throw;
     }
   }
 
   virtual ~OpenSSLTLSIOCtx() override
   {
-    if (d_tlsCtx) {
-      SSL_CTX_free(d_tlsCtx);
-    }
+    d_tlsCtx.reset();
 
     if (s_users.fetch_sub(1) == 1) {
       ERR_free_strings();
@@ -519,7 +490,7 @@ public:
   {
     handleTicketsKeyRotation(now);
 
-    return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx));
+    return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
   }
 
   void rotateTicketsKey(time_t now) override
@@ -565,7 +536,7 @@ public:
 
 private:
   OpenSSLTLSTicketKeysRing d_ticketKeys;
-  SSL_CTX* d_tlsCtx{nullptr};
+  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
   static std::atomic<uint64_t> s_users;
 };
 
@@ -668,7 +639,7 @@ class GnuTLSConnection: public TLSConnection
 {
 public:
 
-  GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey)
+  GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
   {
     unsigned int sslOptions = GNUTLS_SERVER;
 #ifdef GNUTLS_NO_SIGNAL
@@ -677,48 +648,42 @@ public:
 
     d_socket = socket;
 
-    if (gnutls_init(&d_conn, sslOptions) != GNUTLS_E_SUCCESS) {
+    gnutls_session_t conn;
+    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error creating TLS connection");
     }
 
-    if (gnutls_credentials_set(d_conn, GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
-      gnutls_deinit(d_conn);
+    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
+    conn = nullptr;
+
+    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error setting certificate and key to TLS connection");
     }
 
-    if (gnutls_priority_set(d_conn, priorityCache) != GNUTLS_E_SUCCESS) {
-      gnutls_deinit(d_conn);
+    if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error setting ciphers to TLS connection");
     }
 
     if (enableTickets && d_ticketsKey) {
       const gnutls_datum_t& key = d_ticketsKey->getKey();
-      if (gnutls_session_ticket_enable_server(d_conn, &key) != GNUTLS_E_SUCCESS) {
-        gnutls_deinit(d_conn);
+      if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
         throw std::runtime_error("Error setting the tickets key to TLS connection");
       }
     }
 
-    gnutls_transport_set_int(d_conn, d_socket);
+    gnutls_transport_set_int(d_conn.get(), d_socket);
 
     /* timeouts are in milliseconds */
-    gnutls_handshake_set_timeout(d_conn, timeout * 1000);
-    gnutls_record_set_timeout(d_conn, timeout * 1000);
+    gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
+    gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
 
     int ret = 0;
     do {
-      ret = gnutls_handshake(d_conn);
+      ret = gnutls_handshake(d_conn.get());
     }
     while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
   }
 
-  virtual ~GnuTLSConnection() override
-  {
-    if (d_conn) {
-      gnutls_deinit(d_conn);
-    }
-  }
-
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
   {
     size_t got = 0;
@@ -729,7 +694,7 @@ public:
     }
 
     do {
-      ssize_t res = gnutls_record_recv(d_conn, (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
+      ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
       if (res == 0) {
         throw std::runtime_error("Error reading from TLS connection");
       }
@@ -763,7 +728,7 @@ public:
     size_t got = 0;
 
     do {
-      ssize_t res = gnutls_record_send(d_conn, (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
+      ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
       if (res == 0) {
         throw std::runtime_error("Error writing to TLS connection");
       }
@@ -785,40 +750,42 @@ public:
   void close() override
   {
     if (d_conn) {
-      gnutls_bye(d_conn, GNUTLS_SHUT_WR);
+      gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
     }
   }
 
 private:
-  gnutls_session_t d_conn{nullptr};
+  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
 };
 
 class GnuTLSIOCtx: public TLSCtx
 {
 public:
-  GnuTLSIOCtx(const TLSFrontend& fe): d_enableTickets(fe.d_enableTickets)
+  GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets)
   {
     int rc = 0;
     d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
 
-    rc = gnutls_certificate_allocate_credentials(&d_creds);
+    gnutls_certificate_credentials_t creds;
+    rc = gnutls_certificate_allocate_credentials(&creds);
     if (rc != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
     }
 
+    d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+    creds = nullptr;
+
     for (const auto& pair : fe.d_certKeyPairs) {
-      rc = gnutls_certificate_set_x509_key_file(d_creds, pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
+      rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
       if (rc != GNUTLS_E_SUCCESS) {
-        gnutls_certificate_free_credentials(d_creds);
         throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
       }
     }
 
 #if GNUTLS_VERSION_NUMBER >= 0x030600
-    rc = gnutls_certificate_set_known_dh_params(d_creds, GNUTLS_SEC_PARAM_HIGH);
+    rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
     if (rc != GNUTLS_E_SUCCESS) {
-      gnutls_certificate_free_credentials(d_creds);
       throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
     }
 #endif
@@ -837,16 +804,14 @@ public:
       }
     }
     catch(const std::runtime_error& e) {
-      gnutls_certificate_free_credentials(d_creds);
       throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
     }
   }
 
   virtual ~GnuTLSIOCtx() override
   {
-    if (d_creds) {
-      gnutls_certificate_free_credentials(d_creds);
-    }
+    d_creds.reset();
+
     if (d_priorityCache) {
       gnutls_priority_deinit(d_priorityCache);
     }
@@ -856,7 +821,7 @@ public:
   {
     handleTicketsKeyRotation(now);
 
-    return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds, d_priorityCache, d_ticketsKey, d_enableTickets));
+    return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, d_ticketsKey, d_enableTickets));
   }
 
   void rotateTicketsKey(time_t now) override
@@ -891,7 +856,7 @@ public:
   }
 
 private:
-  gnutls_certificate_credentials_t d_creds{nullptr};
+  std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
   gnutls_priority_t d_priorityCache{nullptr};
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
   bool d_enableTickets{true};