granicus.if.org Git - icinga2/blob - lib/base/tlsstream.cpp
Merge pull request #7185 from Icinga/bugfix/gelfwriter-wrong-log-facility
[icinga2] / lib / base / tlsstream.cpp
1 /* Icinga 2 | (c) 2012 Icinga GmbH | GPLv2+ */
2
3 #include "base/application.hpp"
4 #include "base/tlsstream.hpp"
5 #include "base/utility.hpp"
6 #include "base/exception.hpp"
7 #include "base/logger.hpp"
8 #include "base/configuration.hpp"
9 #include "base/convert.hpp"
10 #include <boost/asio/ssl/context.hpp>
11 #include <boost/asio/ssl/verify_context.hpp>
12 #include <boost/asio/ssl/verify_mode.hpp>
13 #include <iostream>
14 #include <openssl/ssl.h>
15 #include <openssl/tls1.h>
16 #include <openssl/x509.h>
17 #include <sstream>
18
19 #ifndef _WIN32
20 #       include <poll.h>
21 #endif /* _WIN32 */
22
23 #define TLS_TIMEOUT_SECONDS 10
24
25 using namespace icinga;
26
27 int TlsStream::m_SSLIndex;
28 bool TlsStream::m_SSLIndexInitialized = false;
29
30 /**
31  * Constructor for the TlsStream class.
32  *
33  * @param role The role of the client.
34  * @param sslContext The SSL context for the client.
35  */
36 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, const std::shared_ptr<SSL_CTX>& sslContext)
37         : TlsStream(socket, hostname, role, sslContext.get())
38 {
39 }
40
41 /**
42  * Constructor for the TlsStream class.
43  *
44  * @param role The role of the client.
45  * @param sslContext The SSL context for the client.
46  */
47 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, const std::shared_ptr<boost::asio::ssl::context>& sslContext)
48         : TlsStream(socket, hostname, role, sslContext->native_handle())
49 {
50 }
51
52 /**
53  * Constructor for the TlsStream class.
54  *
55  * @param role The role of the client.
56  * @param sslContext The SSL context for the client.
57  */
58 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, SSL_CTX* sslContext)
59         : SocketEvents(socket), m_Eof(false), m_HandshakeOK(false), m_VerifyOK(true), m_ErrorCode(0),
60         m_ErrorOccurred(false),  m_Socket(socket), m_Role(role), m_SendQ(new FIFO()), m_RecvQ(new FIFO()),
61         m_CurrentAction(TlsActionNone), m_Retry(false), m_Shutdown(false)
62 {
63         std::ostringstream msgbuf;
64         char errbuf[256];
65
66         m_SSL = std::shared_ptr<SSL>(SSL_new(sslContext), SSL_free);
67
68         if (!m_SSL) {
69                 msgbuf << "SSL_new() failed with code " << ERR_peek_error() << ", \"" << ERR_error_string(ERR_peek_error(), errbuf) << "\"";
70                 Log(LogCritical, "TlsStream", msgbuf.str());
71
72                 BOOST_THROW_EXCEPTION(openssl_error()
73                         << boost::errinfo_api_function("SSL_new")
74                         << errinfo_openssl_error(ERR_peek_error()));
75         }
76
77         if (!m_SSLIndexInitialized) {
78                 m_SSLIndex = SSL_get_ex_new_index(0, const_cast<char *>("TlsStream"), nullptr, nullptr, nullptr);
79                 m_SSLIndexInitialized = true;
80         }
81
82         SSL_set_ex_data(m_SSL.get(), m_SSLIndex, this);
83
84         SSL_set_verify(m_SSL.get(), SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, &TlsStream::ValidateCertificate);
85
86         socket->MakeNonBlocking();
87
88         SSL_set_fd(m_SSL.get(), socket->GetFD());
89
90         if (m_Role == RoleServer)
91                 SSL_set_accept_state(m_SSL.get());
92         else {
93 #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
94                 if (!hostname.IsEmpty())
95                         SSL_set_tlsext_host_name(m_SSL.get(), hostname.CStr());
96 #endif /* SSL_CTRL_SET_TLSEXT_HOSTNAME */
97
98                 SSL_set_connect_state(m_SSL.get());
99         }
100 }
101
102 TlsStream::~TlsStream()
103 {
104         CloseInternal(true);
105 }
106
107 int TlsStream::ValidateCertificate(int preverify_ok, X509_STORE_CTX *ctx)
108 {
109         auto *ssl = static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
110         auto *stream = static_cast<TlsStream *>(SSL_get_ex_data(ssl, m_SSLIndex));
111
112         if (!preverify_ok) {
113                 stream->m_VerifyOK = false;
114
115                 std::ostringstream msgbuf;
116                 int err = X509_STORE_CTX_get_error(ctx);
117                 msgbuf << "code " << err << ": " << X509_verify_cert_error_string(err);
118                 stream->m_VerifyError = msgbuf.str();
119         }
120
121         return 1;
122 }
123
124 bool TlsStream::IsVerifyOK() const
125 {
126         return m_VerifyOK;
127 }
128
129 String TlsStream::GetVerifyError() const
130 {
131         return m_VerifyError;
132 }
133
134 /**
135  * Retrieves the X509 certficate for this client.
136  *
137  * @returns The X509 certificate.
138  */
139 std::shared_ptr<X509> TlsStream::GetClientCertificate() const
140 {
141         boost::mutex::scoped_lock lock(m_Mutex);
142         return std::shared_ptr<X509>(SSL_get_certificate(m_SSL.get()), &Utility::NullDeleter);
143 }
144
145 /**
146  * Retrieves the X509 certficate for the peer.
147  *
148  * @returns The X509 certificate.
149  */
150 std::shared_ptr<X509> TlsStream::GetPeerCertificate() const
151 {
152         boost::mutex::scoped_lock lock(m_Mutex);
153         return std::shared_ptr<X509>(SSL_get_peer_certificate(m_SSL.get()), X509_free);
154 }
155
/**
 * Drives the TLS state machine whenever the socket reports readiness;
 * @p revents holds the poll flags (POLLIN, POLLOUT, POLLERR, POLLHUP) that fired.
 *
 * If no action is pending, one is chosen: read on POLLIN/POLLERR/POLLHUP,
 * otherwise write when the send queue is non-empty and POLLOUT is set. A
 * handshake action is installed by Handshake(). An action stays current until
 * it succeeds: SSL_ERROR_WANT_READ/WANT_WRITE re-arm the matching poll flag so
 * the same action is retried on the next event. A clean TLS close
 * (SSL_ERROR_ZERO_RETURN) closes the stream; any other error additionally
 * records the OpenSSL error code so HandleError() rethrows it to threads
 * waiting in Peek()/Read()/Handshake().
 *
 * Locking: m_Mutex is held while SSL state and the queues are touched; it is
 * released before Close() is called and before SignalDataAvailable() is emitted.
 */
156 void TlsStream::OnEvent(int revents)
157 {
158         int rc;
159         size_t count;
160
161         boost::mutex::scoped_lock lock(m_Mutex);
162
163         if (!m_SSL)
164                 return;
165
        /* Scratch buffer for one SSL_read()/SSL_write() call (64 KiB, on the stack). */
166         char buffer[64 * 1024];
167
        /* Only pick a new action when none is in progress; a pending action
         * (including a handshake started by Handshake()) is retried as-is. */
168         if (m_CurrentAction == TlsActionNone) {
169                 if (revents & (POLLIN | POLLERR | POLLHUP))
170                         m_CurrentAction = TlsActionRead;
171                 else if (m_SendQ->GetAvailableBytes() > 0 && (revents & POLLOUT))
172                         m_CurrentAction = TlsActionWrite;
173                 else {
174                         ChangeEvents(POLLIN);
175
176                         return;
177                 }
178         }
179
180         bool success = false;
181
182         /* Clear error queue for this thread before using SSL_{read,write,do_handshake}.
183          * Otherwise SSL_*_error() does not work reliably.
184          */
185         ERR_clear_error();
186
187         size_t readTotal = 0;
188
189         switch (m_CurrentAction) {
190                 case TlsActionRead:
191                         do {
192                                 rc = SSL_read(m_SSL.get(), buffer, sizeof(buffer));
193
194                                 if (rc > 0) {
195                                         m_RecvQ->Write(buffer, rc);
196                                         success = true;
197
198                                         readTotal += rc;
199                                 }
200
201 #ifdef I2_DEBUG /* I2_DEBUG */
202                                 Log(LogDebug, "TlsStream")
203                                         << "Read bytes: " << rc << " Total read bytes: " << readTotal;
204 #endif /* I2_DEBUG */
205                                 /* Limit read size. We cannot do this check inside the while loop
206                                  * since below should solely check whether OpenSSL has more data
207                                  * or not. */
                                /* NOTE(review): if this limit is hit while SSL_pending() > 0, the
                                 * remaining decrypted bytes stay in OpenSSL's buffer; unless more
                                 * data arrives on the socket, poll may not report POLLIN for
                                 * them. Confirm this cannot stall a reader. */
208                                 if (readTotal >= 64 * 1024) {
209 #ifdef I2_DEBUG /* I2_DEBUG */
210                                         Log(LogWarning, "TlsStream")
211                                                 << "Maximum read bytes exceeded: " << readTotal;
212 #endif /* I2_DEBUG */
213                                         break;
214                                 }
215
216                         /* Use OpenSSL's state machine here to determine whether we need
217                          * to read more data. SSL_has_pending() is available with 1.1.0.
218                          */
219                         } while (SSL_pending(m_SSL.get()));
220
221                         if (success)
222                                 m_CV.notify_all();
223
224                         break;
225                 case TlsActionWrite:
                        /* Peek instead of Read: only the rc bytes SSL_write() reports as
                         * written are removed from the send queue below. */
226                         count = m_SendQ->Peek(buffer, sizeof(buffer), true);
227
228                         rc = SSL_write(m_SSL.get(), buffer, count);
229
230                         if (rc > 0) {
231                                 m_SendQ->Read(nullptr, rc, true);
232                                 success = true;
233                         }
234
235                         break;
236                 case TlsActionHandshake:
237                         rc = SSL_do_handshake(m_SSL.get());
238
239                         if (rc > 0) {
240                                 success = true;
241                                 m_HandshakeOK = true;
242                                 m_CV.notify_all();
243                         }
244
245                         break;
246                 default:
247                         VERIFY(!"Invalid TlsAction");
248         }
249
        /* rc is the result of the last SSL_* call above. For reads this can be a
         * failing call that followed successful ones in the same loop, so both
         * this branch and the success branch below may run. */
250         if (rc <= 0) {
251                 int err = SSL_get_error(m_SSL.get(), rc);
252
253                 switch (err) {
                        /* OpenSSL needs the socket to become readable/writable before the
                         * current action can progress; keep m_CurrentAction and re-arm. */
254                         case SSL_ERROR_WANT_READ:
255                                 m_Retry = true;
256                                 ChangeEvents(POLLIN);
257
258                                 break;
259                         case SSL_ERROR_WANT_WRITE:
260                                 m_Retry = true;
261                                 ChangeEvents(POLLOUT);
262
263                                 break;
                        /* The peer closed the TLS connection cleanly. */
264                         case SSL_ERROR_ZERO_RETURN:
265                                 lock.unlock();
266
267                                 Close();
268
269                                 return;
                        /* Any other failure: remember it for HandleError(), then close. */
270                         default:
271                                 m_ErrorCode = ERR_peek_error();
272                                 m_ErrorOccurred = true;
273
274                                 if (m_ErrorCode != 0) {
275                                         char errbuf[256];
276                                         Log(LogWarning, "TlsStream")
277                                                 << "OpenSSL error: " << ERR_error_string(m_ErrorCode, errbuf);
278                                 } else {
279                                         Log(LogWarning, "TlsStream", "TLS stream was disconnected.");
280                                 }
281
282                                 lock.unlock();
283
284                                 Close();
285
286                                 return;
287                 }
288         }
289
        /* The action completed: go back to idle and re-arm for reading, plus
         * writing if more data is queued (unless the stream is already closed). */
290         if (success) {
291                 m_CurrentAction = TlsActionNone;
292
293                 if (!m_Eof) {
294                         if (m_SendQ->GetAvailableBytes() > 0)
295                                 ChangeEvents(POLLIN|POLLOUT);
296                         else
297                                 ChangeEvents(POLLIN);
298                 }
299
300                 lock.unlock();
301
                /* Emit the data-available signal outside the lock, repeatedly, for
                 * as long as received data remains and events are still handled. */
302                 while (m_RecvQ->IsDataAvailable() && IsHandlingEvents())
303                         SignalDataAvailable();
304         }
305
        /* Graceful shutdown requested via Shutdown(): close once the send queue
         * is drained. When success is true the lock was already released above.
         * NOTE(review): in that case m_Shutdown and m_SendQ are read here
         * without m_Mutex held. */
306         if (m_Shutdown && !m_SendQ->IsDataAvailable()) {
307                 if (!success)
308                         lock.unlock();
309
310                 Close();
311         }
312 }
313
314 void TlsStream::HandleError() const
315 {
316         if (m_ErrorOccurred) {
317                 BOOST_THROW_EXCEPTION(openssl_error()
318                         << boost::errinfo_api_function("TlsStream::OnEvent")
319                         << errinfo_openssl_error(m_ErrorCode));
320         }
321 }
322
323 void TlsStream::Handshake()
324 {
325         boost::mutex::scoped_lock lock(m_Mutex);
326
327         m_CurrentAction = TlsActionHandshake;
328         ChangeEvents(POLLOUT);
329
330         boost::system_time const timeout = boost::get_system_time() + boost::posix_time::milliseconds(long(Configuration::TlsHandshakeTimeout * 1000));
331
332         while (!m_HandshakeOK && !m_ErrorOccurred && !m_Eof && timeout > boost::get_system_time())
333                 m_CV.timed_wait(lock, timeout);
334
335         if (timeout < boost::get_system_time())
336                 BOOST_THROW_EXCEPTION(std::runtime_error("Timeout was reached (" + Convert::ToString(Configuration::TlsHandshakeTimeout) + ") during TLS handshake."));
337
338         if (m_Eof)
339                 BOOST_THROW_EXCEPTION(std::runtime_error("Socket was closed during TLS handshake."));
340
341         HandleError();
342 }
343
344 /**
345  * Processes data for the stream.
346  */
347 size_t TlsStream::Peek(void *buffer, size_t count, bool allow_partial)
348 {
349         boost::mutex::scoped_lock lock(m_Mutex);
350
351         if (!allow_partial)
352                 while (m_RecvQ->GetAvailableBytes() < count && !m_ErrorOccurred && !m_Eof)
353                         m_CV.wait(lock);
354
355         HandleError();
356
357         return m_RecvQ->Peek(buffer, count, true);
358 }
359
360 size_t TlsStream::Read(void *buffer, size_t count, bool allow_partial)
361 {
362         boost::mutex::scoped_lock lock(m_Mutex);
363
364         if (!allow_partial)
365                 while (m_RecvQ->GetAvailableBytes() < count && !m_ErrorOccurred && !m_Eof)
366                         m_CV.wait(lock);
367
368         HandleError();
369
370         return m_RecvQ->Read(buffer, count, true);
371 }
372
373 void TlsStream::Write(const void *buffer, size_t count)
374 {
375         boost::mutex::scoped_lock lock(m_Mutex);
376
377         m_SendQ->Write(buffer, count);
378
379         ChangeEvents(POLLIN|POLLOUT);
380 }
381
382 void TlsStream::Shutdown()
383 {
384         m_Shutdown = true;
385         ChangeEvents(POLLOUT);
386 }
387
388 /**
389  * Closes the stream.
390  */
391 void TlsStream::Close()
392 {
393         CloseInternal(false);
394 }
395
396 void TlsStream::CloseInternal(bool inDestructor)
397 {
398         if (m_Eof)
399                 return;
400
401         m_Eof = true;
402
403         if (!inDestructor)
404                 SignalDataAvailable();
405
406         SocketEvents::Unregister();
407
408         Stream::Close();
409
410         boost::mutex::scoped_lock lock(m_Mutex);
411
412         if (!m_SSL)
413                 return;
414
415         /* https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
416          *
417          * It is recommended to do a bidirectional shutdown by checking
418          * the return value of SSL_shutdown() and call it again until
419          * it returns 1 or a fatal error. A maximum of 2x pending + 2x data
420          * is recommended.
421          */
422         int rc = 0;
423
424         for (int i = 0; i < 4; i++) {
425                 if ((rc = SSL_shutdown(m_SSL.get())))
426                         break;
427         }
428
429         m_SSL.reset();
430
431         m_Socket->Close();
432         m_Socket.reset();
433
434         m_CV.notify_all();
435 }
436
437 bool TlsStream::IsEof() const
438 {
439         return m_Eof && m_RecvQ->GetAvailableBytes() < 1u;
440 }
441
442 bool TlsStream::SupportsWaiting() const
443 {
444         return true;
445 }
446
447 bool TlsStream::IsDataAvailable() const
448 {
449         boost::mutex::scoped_lock lock(m_Mutex);
450
451         return m_RecvQ->GetAvailableBytes() > 0;
452 }
453
454 Socket::Ptr TlsStream::GetSocket() const
455 {
456         return m_Socket;
457 }
458
459 bool UnbufferedAsioTlsStream::IsVerifyOK() const
460 {
461         return m_VerifyOK;
462 }
463
464 String UnbufferedAsioTlsStream::GetVerifyError() const
465 {
466         return m_VerifyError;
467 }
468
469 std::shared_ptr<X509> UnbufferedAsioTlsStream::GetPeerCertificate()
470 {
471         return std::shared_ptr<X509>(SSL_get_peer_certificate(native_handle()), X509_free);
472 }
473
474 void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type)
475 {
476         namespace ssl = boost::asio::ssl;
477
478         set_verify_mode(ssl::verify_peer | ssl::verify_client_once);
479
480         set_verify_callback([this](bool preverified, ssl::verify_context& ctx) {
481                 if (!preverified) {
482                         m_VerifyOK = false;
483
484                         std::ostringstream msgbuf;
485                         int err = X509_STORE_CTX_get_error(ctx.native_handle());
486
487                         msgbuf << "code " << err << ": " << X509_verify_cert_error_string(err);
488                         m_VerifyError = msgbuf.str();
489                 }
490
491                 return true;
492         });
493
494 #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
495         if (type == client && !m_Hostname.IsEmpty()) {
496                 String environmentName = Application::GetAppEnvironment();
497                 String serverName = m_Hostname;
498
499                 if (!environmentName.IsEmpty())
500                         serverName += ":" + environmentName;
501
502                 SSL_set_tlsext_host_name(native_handle(), serverName.CStr());
503         }
504 #endif /* SSL_CTRL_SET_TLSEXT_HOSTNAME */
505 }