class OpenSSLTLSConnection: public TLSConnection
{
public:
- OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx)
+ OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free))
{
d_socket = socket;
- d_conn = SSL_new(tlsCtx);
if (!d_conn) {
vinfolog("Error creating TLS object");
throw std::runtime_error("Error creating TLS object");
}
- if (!SSL_set_fd(d_conn, d_socket)) {
- SSL_free(d_conn);
- d_conn = nullptr;
+ if (!SSL_set_fd(d_conn.get(), d_socket)) {
throw std::runtime_error("Error assigning socket");
}
int res = 0;
do {
- res = SSL_accept(d_conn);
+ res = SSL_accept(d_conn.get());
if (res < 0) {
- try {
- handleIORequest(res, timeout);
- }
- catch(...) {
- SSL_free(d_conn);
- d_conn = nullptr;
- throw;
- }
+ handleIORequest(res, timeout);
}
}
while (res < 0);
if (res != 1) {
- SSL_free(d_conn);
- d_conn = nullptr;
throw std::runtime_error("Error accepting TLS connection");
}
}
- virtual ~OpenSSLTLSConnection() override
- {
- if (d_conn) {
- SSL_free(d_conn);
- }
- }
-
void handleIORequest(int res, unsigned int timeout)
{
- int error = SSL_get_error(d_conn, res);
+ int error = SSL_get_error(d_conn.get(), res);
if (error == SSL_ERROR_WANT_READ) {
res = waitForData(d_socket, timeout);
if (res <= 0) {
}
do {
- int res = SSL_read(d_conn, (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
if (res == 0) {
throw std::runtime_error("Error reading from TLS connection");
}
{
size_t got = 0;
do {
- int res = SSL_write(d_conn, (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
if (res == 0) {
throw std::runtime_error("Error writing to TLS connection");
}
void close() override
{
if (d_conn) {
- SSL_shutdown(d_conn);
+ SSL_shutdown(d_conn.get());
}
}
private:
- SSL* d_conn{nullptr};
+ std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
};
class OpenSSLTLSIOCtx: public TLSCtx
{
public:
- OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys)
+ OpenSSLTLSIOCtx(const TLSFrontend& fe): d_ticketKeys(fe.d_numberOfTicketsKeys), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
{
d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
}
}
- d_tlsCtx = SSL_CTX_new(SSLv23_server_method());
+ d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
if (!d_tlsCtx) {
ERR_print_errors_fp(stderr);
throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
}
/* use the internal built-in cache to store sessions */
- SSL_CTX_set_session_cache_mode(d_tlsCtx, SSL_SESS_CACHE_SERVER);
+ SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_SERVER);
/* use our own ticket keys handler so we can rotate them */
- SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx, &OpenSSLTLSIOCtx::ticketKeyCb);
- SSL_CTX_set_ex_data(d_tlsCtx, s_ticketsKeyIndex, this);
- SSL_CTX_set_options(d_tlsCtx, sslOptions);
+ SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+ SSL_CTX_set_ex_data(d_tlsCtx.get(), s_ticketsKeyIndex, this);
+ SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
#if defined(SSL_CTX_set_ecdh_auto)
- SSL_CTX_set_ecdh_auto(d_tlsCtx, 1);
+ SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
#endif
for (const auto& pair : fe.d_certKeyPairs) {
- if (SSL_CTX_use_certificate_chain_file(d_tlsCtx, pair.first.c_str()) != 1) {
+ if (SSL_CTX_use_certificate_chain_file(d_tlsCtx.get(), pair.first.c_str()) != 1) {
ERR_print_errors_fp(stderr);
- SSL_CTX_free(d_tlsCtx);
- d_tlsCtx = nullptr;
throw std::runtime_error("Error loading certificate from " + pair.first + " for the TLS context on " + fe.d_addr.toStringWithPort());
}
- if (SSL_CTX_use_PrivateKey_file(d_tlsCtx, pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
+ if (SSL_CTX_use_PrivateKey_file(d_tlsCtx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
ERR_print_errors_fp(stderr);
- SSL_CTX_free(d_tlsCtx);
- d_tlsCtx = nullptr;
throw std::runtime_error("Error loading key from " + pair.second + " for the TLS context on " + fe.d_addr.toStringWithPort());
}
}
if (!fe.d_ciphers.empty()) {
- if (SSL_CTX_set_cipher_list(d_tlsCtx, fe.d_ciphers.c_str()) != 1) {
+ if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), fe.d_ciphers.c_str()) != 1) {
ERR_print_errors_fp(stderr);
- SSL_CTX_free(d_tlsCtx);
- d_tlsCtx = nullptr;
throw std::runtime_error("Error setting the cipher list to '" + fe.d_ciphers + "' for the TLS context on " + fe.d_addr.toStringWithPort());
}
}
}
}
catch (const std::exception& e) {
- SSL_CTX_free(d_tlsCtx);
- d_tlsCtx = nullptr;
throw;
}
}
virtual ~OpenSSLTLSIOCtx() override
{
- if (d_tlsCtx) {
- SSL_CTX_free(d_tlsCtx);
- }
+ d_tlsCtx.reset();
if (s_users.fetch_sub(1) == 1) {
ERR_free_strings();
{
handleTicketsKeyRotation(now);
- return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx));
+ return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
}
void rotateTicketsKey(time_t now) override
private:
OpenSSLTLSTicketKeysRing d_ticketKeys;
- SSL_CTX* d_tlsCtx{nullptr};
+ std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
static std::atomic<uint64_t> s_users;
};
{
public:
- GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey)
+ GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
{
unsigned int sslOptions = GNUTLS_SERVER;
#ifdef GNUTLS_NO_SIGNAL
d_socket = socket;
- if (gnutls_init(&d_conn, sslOptions) != GNUTLS_E_SUCCESS) {
+ gnutls_session_t conn;
+ if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error creating TLS connection");
}
- if (gnutls_credentials_set(d_conn, GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
- gnutls_deinit(d_conn);
+ d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
+ conn = nullptr;
+
+ if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error setting certificate and key to TLS connection");
}
- if (gnutls_priority_set(d_conn, priorityCache) != GNUTLS_E_SUCCESS) {
- gnutls_deinit(d_conn);
+ if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error setting ciphers to TLS connection");
}
if (enableTickets && d_ticketsKey) {
const gnutls_datum_t& key = d_ticketsKey->getKey();
- if (gnutls_session_ticket_enable_server(d_conn, &key) != GNUTLS_E_SUCCESS) {
- gnutls_deinit(d_conn);
+ if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error setting the tickets key to TLS connection");
}
}
- gnutls_transport_set_int(d_conn, d_socket);
+ gnutls_transport_set_int(d_conn.get(), d_socket);
/* timeouts are in milliseconds */
- gnutls_handshake_set_timeout(d_conn, timeout * 1000);
- gnutls_record_set_timeout(d_conn, timeout * 1000);
+ gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
+ gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
int ret = 0;
do {
- ret = gnutls_handshake(d_conn);
+ ret = gnutls_handshake(d_conn.get());
}
while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
}
- virtual ~GnuTLSConnection() override
- {
- if (d_conn) {
- gnutls_deinit(d_conn);
- }
- }
-
size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
{
size_t got = 0;
}
do {
- ssize_t res = gnutls_record_recv(d_conn, (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
+ ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
if (res == 0) {
throw std::runtime_error("Error reading from TLS connection");
}
size_t got = 0;
do {
- ssize_t res = gnutls_record_send(d_conn, (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
+ ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
if (res == 0) {
throw std::runtime_error("Error writing to TLS connection");
}
void close() override
{
if (d_conn) {
- gnutls_bye(d_conn, GNUTLS_SHUT_WR);
+ gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
}
}
private:
- gnutls_session_t d_conn{nullptr};
+ std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
};
class GnuTLSIOCtx: public TLSCtx
{
public:
- GnuTLSIOCtx(const TLSFrontend& fe): d_enableTickets(fe.d_enableTickets)
+ GnuTLSIOCtx(const TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_enableTickets)
{
int rc = 0;
d_ticketsKeyRotationDelay = fe.d_ticketsKeyRotationDelay;
- rc = gnutls_certificate_allocate_credentials(&d_creds);
+ gnutls_certificate_credentials_t creds;
+ rc = gnutls_certificate_allocate_credentials(&creds);
if (rc != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
}
+ d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+ creds = nullptr;
+
for (const auto& pair : fe.d_certKeyPairs) {
- rc = gnutls_certificate_set_x509_key_file(d_creds, pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
+ rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
if (rc != GNUTLS_E_SUCCESS) {
- gnutls_certificate_free_credentials(d_creds);
throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
}
}
#if GNUTLS_VERSION_NUMBER >= 0x030600
- rc = gnutls_certificate_set_known_dh_params(d_creds, GNUTLS_SEC_PARAM_HIGH);
+ rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
if (rc != GNUTLS_E_SUCCESS) {
- gnutls_certificate_free_credentials(d_creds);
throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
}
#endif
}
}
catch(const std::runtime_error& e) {
- gnutls_certificate_free_credentials(d_creds);
throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
}
}
virtual ~GnuTLSIOCtx() override
{
- if (d_creds) {
- gnutls_certificate_free_credentials(d_creds);
- }
+ d_creds.reset();
+
if (d_priorityCache) {
gnutls_priority_deinit(d_priorityCache);
}
{
handleTicketsKeyRotation(now);
- return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds, d_priorityCache, d_ticketsKey, d_enableTickets));
+ return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, d_ticketsKey, d_enableTickets));
}
void rotateTicketsKey(time_t now) override
}
private:
- gnutls_certificate_credentials_t d_creds{nullptr};
+ std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
gnutls_priority_t d_priorityCache{nullptr};
std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
bool d_enableTickets{true};