]> granicus.if.org Git - icinga2/blobdiff - lib/base/tlsstream.cpp
Merge pull request #6718 from Icinga/bugfix/ssl-shutdown
[icinga2] / lib / base / tlsstream.cpp
index c9f14667ff053cb628f43d04ca6b0cf935f1966c..b771b3622600057838a7e76ee8b2d70cf8bfb7c7 100644 (file)
@@ -1,6 +1,6 @@
 /******************************************************************************
  * Icinga 2                                                                   *
- * Copyright (C) 2012-2017 Icinga Development Team (https://www.icinga.com/)  *
+ * Copyright (C) 2012-2018 Icinga Development Team (https://icinga.com/)      *
  *                                                                            *
  * This program is free software; you can redistribute it and/or              *
  * modify it under the terms of the GNU General Public License                *
 #include "base/utility.hpp"
 #include "base/exception.hpp"
 #include "base/logger.hpp"
+#include "base/configuration.hpp"
+#include "base/convert.hpp"
 #include <iostream>
 
 #ifndef _WIN32
 #      include <poll.h>
 #endif /* _WIN32 */
 
+#define TLS_TIMEOUT_SECONDS 10
+
 using namespace icinga;
 
-int I2_EXPORT TlsStream::m_SSLIndex;
-bool I2_EXPORT TlsStream::m_SSLIndexInitialized = false;
+int TlsStream::m_SSLIndex;
+bool TlsStream::m_SSLIndexInitialized = false;
 
 /**
  * Constructor for the TlsStream class.
@@ -40,8 +44,8 @@ bool I2_EXPORT TlsStream::m_SSLIndexInitialized = false;
  */
 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, const std::shared_ptr<SSL_CTX>& sslContext)
        : SocketEvents(socket, this), m_Eof(false), m_HandshakeOK(false), m_VerifyOK(true), m_ErrorCode(0),
-         m_ErrorOccurred(false),  m_Socket(socket), m_Role(role), m_SendQ(new FIFO()), m_RecvQ(new FIFO()),
-         m_CurrentAction(TlsActionNone), m_Retry(false), m_Shutdown(false)
+       m_ErrorOccurred(false),  m_Socket(socket), m_Role(role), m_SendQ(new FIFO()), m_RecvQ(new FIFO()),
+       m_CurrentAction(TlsActionNone), m_Retry(false), m_Shutdown(false)
 {
        std::ostringstream msgbuf;
        char errbuf[120];
@@ -82,15 +86,15 @@ TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, Connecti
        }
 }
 
-TlsStream::~TlsStream(void)
+TlsStream::~TlsStream()
 {
        CloseInternal(true);
 }
 
 int TlsStream::ValidateCertificate(int preverify_ok, X509_STORE_CTX *ctx)
 {
-       SSL *ssl = static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
-       TlsStream *stream = static_cast<TlsStream *>(SSL_get_ex_data(ssl, m_SSLIndex));
+       auto *ssl = static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
+       auto *stream = static_cast<TlsStream *>(SSL_get_ex_data(ssl, m_SSLIndex));
 
        if (!preverify_ok) {
                stream->m_VerifyOK = false;
@@ -104,12 +108,12 @@ int TlsStream::ValidateCertificate(int preverify_ok, X509_STORE_CTX *ctx)
        return 1;
 }
 
-bool TlsStream::IsVerifyOK(void) const
+bool TlsStream::IsVerifyOK() const
 {
        return m_VerifyOK;
 }
 
-String TlsStream::GetVerifyError(void) const
+String TlsStream::GetVerifyError() const
 {
        return m_VerifyError;
 }
@@ -119,7 +123,7 @@ String TlsStream::GetVerifyError(void) const
  *
  * @returns The X509 certificate.
  */
-std::shared_ptr<X509> TlsStream::GetClientCertificate(void) const
+std::shared_ptr<X509> TlsStream::GetClientCertificate() const
 {
        boost::mutex::scoped_lock lock(m_Mutex);
        return std::shared_ptr<X509>(SSL_get_certificate(m_SSL.get()), &Utility::NullDeleter);
@@ -130,7 +134,7 @@ std::shared_ptr<X509> TlsStream::GetClientCertificate(void) const
  *
  * @returns The X509 certificate.
  */
-std::shared_ptr<X509> TlsStream::GetPeerCertificate(void) const
+std::shared_ptr<X509> TlsStream::GetPeerCertificate() const
 {
        boost::mutex::scoped_lock lock(m_Mutex);
        return std::shared_ptr<X509>(SSL_get_peer_certificate(m_SSL.get()), X509_free);
@@ -149,12 +153,17 @@ void TlsStream::OnEvent(int revents)
        char buffer[64 * 1024];
 
        if (m_CurrentAction == TlsActionNone) {
-               if (revents & (POLLIN | POLLERR | POLLHUP))
+               bool corked = IsCorked();
+               if (!corked && (revents & (POLLIN | POLLERR | POLLHUP)))
                        m_CurrentAction = TlsActionRead;
                else if (m_SendQ->GetAvailableBytes() > 0 && (revents & POLLOUT))
                        m_CurrentAction = TlsActionWrite;
                else {
-                       ChangeEvents(POLLIN);
+                       if (corked)
+                               ChangeEvents(0);
+                       else
+                               ChangeEvents(POLLIN);
+
                        return;
                }
        }
@@ -166,6 +175,8 @@ void TlsStream::OnEvent(int revents)
         */
        ERR_clear_error();
 
+       size_t readTotal = 0;
+
        switch (m_CurrentAction) {
                case TlsActionRead:
                        do {
@@ -174,8 +185,29 @@ void TlsStream::OnEvent(int revents)
                                if (rc > 0) {
                                        m_RecvQ->Write(buffer, rc);
                                        success = true;
+
+                                       readTotal += rc;
                                }
-                       } while (rc > 0);
+
+#ifdef I2_DEBUG /* I2_DEBUG */
+                               Log(LogDebug, "TlsStream")
+                                       << "Read bytes: " << rc << " Total read bytes: " << readTotal;
+#endif /* I2_DEBUG */
+                               /* Limit read size. We cannot do this check inside the while loop
+                                * since below should solely check whether OpenSSL has more data
+                                * or not. */
+                               if (readTotal >= 64 * 1024) {
+#ifdef I2_DEBUG /* I2_DEBUG */
+                                       Log(LogWarning, "TlsStream")
+                                               << "Maximum read bytes exceeded: " << readTotal;
+#endif /* I2_DEBUG */
+                                       break;
+                               }
+
+                       /* Use OpenSSL's state machine here to determine whether we need
+                        * to read more data. SSL_has_pending() is available with 1.1.0.
+                        */
+                       } while (SSL_pending(m_SSL.get()));
 
                        if (success)
                                m_CV.notify_all();
@@ -257,7 +289,7 @@ void TlsStream::OnEvent(int revents)
 
                lock.unlock();
 
-               while (m_RecvQ->IsDataAvailable() && IsHandlingEvents())
+               while (!IsCorked() && m_RecvQ->IsDataAvailable() && IsHandlingEvents())
                        SignalDataAvailable();
        }
 
@@ -269,24 +301,29 @@ void TlsStream::OnEvent(int revents)
        }
 }
 
-void TlsStream::HandleError(void) const
+void TlsStream::HandleError() const
 {
        if (m_ErrorOccurred) {
                BOOST_THROW_EXCEPTION(openssl_error()
-                   << boost::errinfo_api_function("TlsStream::OnEvent")
-                   << errinfo_openssl_error(m_ErrorCode));
+                       << boost::errinfo_api_function("TlsStream::OnEvent")
+                       << errinfo_openssl_error(m_ErrorCode));
        }
 }
 
-void TlsStream::Handshake(void)
+void TlsStream::Handshake()
 {
        boost::mutex::scoped_lock lock(m_Mutex);
 
        m_CurrentAction = TlsActionHandshake;
        ChangeEvents(POLLOUT);
 
-       while (!m_HandshakeOK && !m_ErrorOccurred && !m_Eof)
-               m_CV.wait(lock);
+       boost::system_time const timeout = boost::get_system_time() + boost::posix_time::milliseconds(long(Configuration::TlsHandshakeTimeout * 1000));
+
+       while (!m_HandshakeOK && !m_ErrorOccurred && !m_Eof && timeout > boost::get_system_time())
+               m_CV.timed_wait(lock, timeout);
+
+       if (timeout < boost::get_system_time())
+               BOOST_THROW_EXCEPTION(std::runtime_error("Timeout was reached (" + Convert::ToString(Configuration::TlsHandshakeTimeout) + ") during TLS handshake."));
 
        if (m_Eof)
                BOOST_THROW_EXCEPTION(std::runtime_error("Socket was closed during TLS handshake."));
@@ -332,7 +369,7 @@ void TlsStream::Write(const void *buffer, size_t count)
        ChangeEvents(POLLIN|POLLOUT);
 }
 
-void TlsStream::Shutdown(void)
+void TlsStream::Shutdown()
 {
        m_Shutdown = true;
        ChangeEvents(POLLOUT);
@@ -341,7 +378,7 @@ void TlsStream::Shutdown(void)
 /**
  * Closes the stream.
  */
-void TlsStream::Close(void)
+void TlsStream::Close()
 {
        CloseInternal(false);
 }
@@ -365,7 +402,20 @@ void TlsStream::CloseInternal(bool inDestructor)
        if (!m_SSL)
                return;
 
-       (void)SSL_shutdown(m_SSL.get());
+       /* https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
+        *
+        * It is recommended to do a bidirectional shutdown by checking
+        * the return value of SSL_shutdown() and call it again until
+        * it returns 1 or a fatal error. A maximum of 2x pending + 2x data
+        * is recommended.
+         */
+       int rc = 0;
+
+       for (int i = 0; i < 4; i++) {
+               if ((rc = SSL_shutdown(m_SSL.get())))
+                       break;
+       }
+
        m_SSL.reset();
 
        m_Socket->Close();
@@ -374,24 +424,36 @@ void TlsStream::CloseInternal(bool inDestructor)
        m_CV.notify_all();
 }
 
-bool TlsStream::IsEof(void) const
+bool TlsStream::IsEof() const
 {
-       return m_Eof;
+       return m_Eof && m_RecvQ->GetAvailableBytes() < 1u;
 }
 
-bool TlsStream::SupportsWaiting(void) const
+bool TlsStream::SupportsWaiting() const
 {
        return true;
 }
 
-bool TlsStream::IsDataAvailable(void) const
+bool TlsStream::IsDataAvailable() const
 {
        boost::mutex::scoped_lock lock(m_Mutex);
 
        return m_RecvQ->GetAvailableBytes() > 0;
 }
 
-Socket::Ptr TlsStream::GetSocket(void) const
+void TlsStream::SetCorked(bool corked)
+{
+       Stream::SetCorked(corked);
+
+       boost::mutex::scoped_lock lock(m_Mutex);
+
+       if (corked)
+               m_CurrentAction = TlsActionNone;
+       else
+               ChangeEvents(POLLIN | POLLOUT);
+}
+
+Socket::Ptr TlsStream::GetSocket() const
 {
        return m_Socket;
 }