]> granicus.if.org Git - icinga2/commitdiff
Use new I/O engine in PkiUtility::FetchCert() and PkiUtility::RequestCertificate() 7133/head
authorAlexander A. Klimov <alexander.klimov@icinga.com>
Mon, 25 Feb 2019 17:58:04 +0000 (18:58 +0100)
committerAlexander A. Klimov <alexander.klimov@icinga.com>
Mon, 1 Apr 2019 15:18:00 +0000 (17:18 +0200)
lib/remote/pkiutility.cpp

index e1e78528865ce004cc3746d66d26d6f2a4ba3773..c08989dd8a98a065da1ba00be794f0a9689492e8 100644 (file)
@@ -2,8 +2,11 @@
 
 #include "remote/pkiutility.hpp"
 #include "remote/apilistener.hpp"
+#include "base/defer.hpp"
+#include "base/io-engine.hpp"
 #include "base/logger.hpp"
 #include "base/application.hpp"
+#include "base/tcpsocket.hpp"
 #include "base/tlsutility.hpp"
 #include "base/console.hpp"
 #include "base/tlsstream.hpp"
@@ -14,6 +17,7 @@
 #include "remote/jsonrpc.hpp"
 #include <fstream>
 #include <iostream>
+#include <boost/asio/ssl/context.hpp>
 
 using namespace icinga;
 
@@ -76,41 +80,43 @@ int PkiUtility::SignCsr(const String& csrfile, const String& certfile)
 
 std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& port)
 {
-       TcpSocket::Ptr client = new TcpSocket();
+       std::shared_ptr<boost::asio::ssl::context> sslContext;
 
        try {
-               client->Connect(host, port);
+               sslContext = MakeAsioSslContext();
        } catch (const std::exception& ex) {
                Log(LogCritical, "pki")
-                       << "Cannot connect to host '" << host << "' on port '" << port << "'";
+                       << "Cannot make SSL context.";
                Log(LogDebug, "pki")
-                       << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
+                       << "Cannot make SSL context:\n"  << DiagnosticInformation(ex);
                return std::shared_ptr<X509>();
        }
 
-       std::shared_ptr<SSL_CTX> sslContext;
+       auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
 
        try {
-               sslContext = MakeSSLContext();
+               Connect(stream->lowest_layer(), host, port);
        } catch (const std::exception& ex) {
                Log(LogCritical, "pki")
-                       << "Cannot make SSL context.";
+                       << "Cannot connect to host '" << host << "' on port '" << port << "'";
                Log(LogDebug, "pki")
-                       << "Cannot make SSL context:\n"  << DiagnosticInformation(ex);
+                       << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
                return std::shared_ptr<X509>();
        }
 
-       TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
+       auto& sslConn (stream->next_layer());
 
        try {
-               stream->Handshake();
+               sslConn.handshake(sslConn.client);
        } catch (const std::exception& ex) {
                Log(LogCritical, "pki")
                        << "Client TLS handshake failed. (" << ex.what() << ")";
                return std::shared_ptr<X509>();
        }
 
-       return stream->GetPeerCertificate();
+       Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
+
+       return sslConn.GetPeerCertificate();
 }
 
 int PkiUtility::WriteCert(const std::shared_ptr<X509>& cert, const String& trustedfile)
@@ -142,41 +148,43 @@ int PkiUtility::GenTicket(const String& cn, const String& salt, std::ostream& ti
 int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile,
        const String& certfile, const String& cafile, const std::shared_ptr<X509>& trustedCert, const String& ticket)
 {
-       TcpSocket::Ptr client = new TcpSocket();
+       std::shared_ptr<boost::asio::ssl::context> sslContext;
 
        try {
-               client->Connect(host, port);
+               sslContext = MakeAsioSslContext(certfile, keyfile);
        } catch (const std::exception& ex) {
                Log(LogCritical, "cli")
-                       << "Cannot connect to host '" << host << "' on port '" << port << "'";
+                       << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'.";
                Log(LogDebug, "cli")
-                       << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
+                       << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "':\n"  << DiagnosticInformation(ex);
                return 1;
        }
 
-       std::shared_ptr<SSL_CTX> sslContext;
+       auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
 
        try {
-               sslContext = MakeSSLContext(certfile, keyfile);
+               Connect(stream->lowest_layer(), host, port);
        } catch (const std::exception& ex) {
                Log(LogCritical, "cli")
-                       << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'.";
+                       << "Cannot connect to host '" << host << "' on port '" << port << "'";
                Log(LogDebug, "cli")
-                       << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "':\n"  << DiagnosticInformation(ex);
+                       << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
                return 1;
        }
 
-       TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
+       auto& sslConn (stream->next_layer());
 
        try {
-               stream->Handshake();
+               sslConn.handshake(sslConn.client);
        } catch (const std::exception& ex) {
                Log(LogCritical, "cli")
                        << "Client TLS handshake failed: " << DiagnosticInformation(ex, false);
                return 1;
        }
 
-       std::shared_ptr<X509> peerCert = stream->GetPeerCertificate();
+       Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
+
+       auto peerCert (sslConn.GetPeerCertificate());
 
        if (X509_cmp(peerCert.get(), trustedCert.get())) {
                Log(LogCritical, "cli", "Peer certificate does not match trusted certificate.");
@@ -196,36 +204,32 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
                { "params", params }
        });
 
-       JsonRpc::SendMessage(stream, request);
-
-       String jsonString;
        Dictionary::Ptr response;
-       StreamReadContext src;
-
-       for (;;) {
-               StreamReadStatus srs = JsonRpc::ReadMessage(stream, &jsonString, src);
-
-               if (srs == StatusEof)
-                       break;
 
-               if (srs != StatusNewItem)
-                       continue;
+       try {
+               JsonRpc::SendMessage(stream, request);
+               stream->flush();
 
-               response = JsonRpc::DecodeMessage(jsonString);
+               for (;;) {
+                       response = JsonRpc::DecodeMessage(JsonRpc::ReadMessage(stream));
 
-               if (response && response->Contains("error")) {
-                       Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
+                       if (response && response->Contains("error")) {
+                               Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
 #ifdef I2_DEBUG
-                       /* we shouldn't expose master errors to the user in production environments */
-                       Log(LogCritical, "cli", response->Get("error"));
+                               /* we shouldn't expose master errors to the user in production environments */
+                               Log(LogCritical, "cli", response->Get("error"));
 #endif /* I2_DEBUG */
-                       return 1;
-               }
+                               return 1;
+                       }
 
-               if (response && (response->Get("id") != msgid))
-                       continue;
+                       if (response && (response->Get("id") != msgid))
+                               continue;
 
-               break;
+                       break;
+               }
+       } catch (...) {
+               Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log.");
+               return 1;
        }
 
        if (!response) {