]> granicus.if.org Git - icinga2/blob - lib/base/tlsstream.cpp
Make ApiListener#m_SSLContext a Boost ASIO SSL context
[icinga2] / lib / base / tlsstream.cpp
1 /* Icinga 2 | (c) 2012 Icinga GmbH | GPLv2+ */
2
3 #include "base/tlsstream.hpp"
4 #include "base/utility.hpp"
5 #include "base/exception.hpp"
6 #include "base/logger.hpp"
7 #include "base/configuration.hpp"
8 #include "base/convert.hpp"
9 #include <boost/asio/ssl/context.hpp>
10 #include <iostream>
11
12 #ifndef _WIN32
13 #       include <poll.h>
14 #endif /* _WIN32 */
15
16 #define TLS_TIMEOUT_SECONDS 10
17
18 using namespace icinga;
19
20 int TlsStream::m_SSLIndex;
21 bool TlsStream::m_SSLIndexInitialized = false;
22
23 /**
24  * Constructor for the TlsStream class.
25  *
26  * @param role The role of the client.
27  * @param sslContext The SSL context for the client.
28  */
29 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, const std::shared_ptr<SSL_CTX>& sslContext)
30         : TlsStream(socket, hostname, role, sslContext.get())
31 {
32 }
33
34 /**
35  * Constructor for the TlsStream class.
36  *
37  * @param role The role of the client.
38  * @param sslContext The SSL context for the client.
39  */
40 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, const std::shared_ptr<boost::asio::ssl::context>& sslContext)
41         : TlsStream(socket, hostname, role, sslContext->native_handle())
42 {
43 }
44
45 /**
46  * Constructor for the TlsStream class.
47  *
48  * @param role The role of the client.
49  * @param sslContext The SSL context for the client.
50  */
51 TlsStream::TlsStream(const Socket::Ptr& socket, const String& hostname, ConnectionRole role, SSL_CTX* sslContext)
52         : SocketEvents(socket), m_Eof(false), m_HandshakeOK(false), m_VerifyOK(true), m_ErrorCode(0),
53         m_ErrorOccurred(false),  m_Socket(socket), m_Role(role), m_SendQ(new FIFO()), m_RecvQ(new FIFO()),
54         m_CurrentAction(TlsActionNone), m_Retry(false), m_Shutdown(false)
55 {
56         std::ostringstream msgbuf;
57         char errbuf[120];
58
59         m_SSL = std::shared_ptr<SSL>(SSL_new(sslContext), SSL_free);
60
61         if (!m_SSL) {
62                 msgbuf << "SSL_new() failed with code " << ERR_peek_error() << ", \"" << ERR_error_string(ERR_peek_error(), errbuf) << "\"";
63                 Log(LogCritical, "TlsStream", msgbuf.str());
64
65                 BOOST_THROW_EXCEPTION(openssl_error()
66                         << boost::errinfo_api_function("SSL_new")
67                         << errinfo_openssl_error(ERR_peek_error()));
68         }
69
70         if (!m_SSLIndexInitialized) {
71                 m_SSLIndex = SSL_get_ex_new_index(0, const_cast<char *>("TlsStream"), nullptr, nullptr, nullptr);
72                 m_SSLIndexInitialized = true;
73         }
74
75         SSL_set_ex_data(m_SSL.get(), m_SSLIndex, this);
76
77         SSL_set_verify(m_SSL.get(), SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, &TlsStream::ValidateCertificate);
78
79         socket->MakeNonBlocking();
80
81         SSL_set_fd(m_SSL.get(), socket->GetFD());
82
83         if (m_Role == RoleServer)
84                 SSL_set_accept_state(m_SSL.get());
85         else {
86 #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
87                 if (!hostname.IsEmpty())
88                         SSL_set_tlsext_host_name(m_SSL.get(), hostname.CStr());
89 #endif /* SSL_CTRL_SET_TLSEXT_HOSTNAME */
90
91                 SSL_set_connect_state(m_SSL.get());
92         }
93 }
94
95 TlsStream::~TlsStream()
96 {
97         CloseInternal(true);
98 }
99
100 int TlsStream::ValidateCertificate(int preverify_ok, X509_STORE_CTX *ctx)
101 {
102         auto *ssl = static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
103         auto *stream = static_cast<TlsStream *>(SSL_get_ex_data(ssl, m_SSLIndex));
104
105         if (!preverify_ok) {
106                 stream->m_VerifyOK = false;
107
108                 std::ostringstream msgbuf;
109                 int err = X509_STORE_CTX_get_error(ctx);
110                 msgbuf << "code " << err << ": " << X509_verify_cert_error_string(err);
111                 stream->m_VerifyError = msgbuf.str();
112         }
113
114         return 1;
115 }
116
117 bool TlsStream::IsVerifyOK() const
118 {
119         return m_VerifyOK;
120 }
121
122 String TlsStream::GetVerifyError() const
123 {
124         return m_VerifyError;
125 }
126
127 /**
128  * Retrieves the X509 certficate for this client.
129  *
130  * @returns The X509 certificate.
131  */
132 std::shared_ptr<X509> TlsStream::GetClientCertificate() const
133 {
134         boost::mutex::scoped_lock lock(m_Mutex);
135         return std::shared_ptr<X509>(SSL_get_certificate(m_SSL.get()), &Utility::NullDeleter);
136 }
137
138 /**
139  * Retrieves the X509 certficate for the peer.
140  *
141  * @returns The X509 certificate.
142  */
143 std::shared_ptr<X509> TlsStream::GetPeerCertificate() const
144 {
145         boost::mutex::scoped_lock lock(m_Mutex);
146         return std::shared_ptr<X509>(SSL_get_peer_certificate(m_SSL.get()), X509_free);
147 }
148
/**
 * Socket event callback that drives the TLS state machine.
 *
 * If no action is pending, picks one from the poll events:
 * - POLLIN/POLLERR/POLLHUP: read.
 * - POLLOUT with queued outgoing data: write.
 * - Otherwise: just wait for POLLIN.
 *
 * Then runs SSL_read / SSL_write / SSL_do_handshake and translates OpenSSL's
 * result into new poll interest, a stored error, or closing the stream.
 *
 * @param revents The poll event bits reported for the socket.
 */
void TlsStream::OnEvent(int revents)
{
	int rc;
	size_t count;

	boost::mutex::scoped_lock lock(m_Mutex);

	/* The stream has already been closed (CloseInternal() resets m_SSL). */
	if (!m_SSL)
		return;

	char buffer[64 * 1024];

	/* Only choose a new action when none is in progress; an action that
	 * previously returned WANT_READ/WANT_WRITE is retried as-is. */
	if (m_CurrentAction == TlsActionNone) {
		if (revents & (POLLIN | POLLERR | POLLHUP))
			m_CurrentAction = TlsActionRead;
		else if (m_SendQ->GetAvailableBytes() > 0 && (revents & POLLOUT))
			m_CurrentAction = TlsActionWrite;
		else {
			ChangeEvents(POLLIN);

			return;
		}
	}

	bool success = false;

	/* Clear error queue for this thread before using SSL_{read,write,do_handshake}.
	 * Otherwise SSL_*_error() does not work reliably.
	 */
	ERR_clear_error();

	size_t readTotal = 0;

	switch (m_CurrentAction) {
		case TlsActionRead:
			/* Drain decrypted data into the receive queue, capped at 64 KiB per event. */
			do {
				rc = SSL_read(m_SSL.get(), buffer, sizeof(buffer));

				if (rc > 0) {
					m_RecvQ->Write(buffer, rc);
					success = true;

					readTotal += rc;
				}

#ifdef I2_DEBUG /* I2_DEBUG */
				Log(LogDebug, "TlsStream")
					<< "Read bytes: " << rc << " Total read bytes: " << readTotal;
#endif /* I2_DEBUG */
				/* Limit read size. We cannot do this check inside the while loop
				 * since below should solely check whether OpenSSL has more data
				 * or not. */
				if (readTotal >= 64 * 1024) {
#ifdef I2_DEBUG /* I2_DEBUG */
					Log(LogWarning, "TlsStream")
						<< "Maximum read bytes exceeded: " << readTotal;
#endif /* I2_DEBUG */
					break;
				}

			/* Use OpenSSL's state machine here to determine whether we need
			 * to read more data. SSL_has_pending() is available with 1.1.0.
			 */
			} while (SSL_pending(m_SSL.get()));

			/* Wake readers blocked in Peek()/Read(). */
			if (success)
				m_CV.notify_all();

			break;
		case TlsActionWrite:
			/* Peek first and only consume what SSL_write() actually accepted. */
			count = m_SendQ->Peek(buffer, sizeof(buffer), true);

			rc = SSL_write(m_SSL.get(), buffer, count);

			if (rc > 0) {
				m_SendQ->Read(nullptr, rc, true);
				success = true;
			}

			break;
		case TlsActionHandshake:
			rc = SSL_do_handshake(m_SSL.get());

			if (rc > 0) {
				success = true;
				m_HandshakeOK = true;
				/* Wake the thread waiting in Handshake(). */
				m_CV.notify_all();
			}

			break;
		default:
			VERIFY(!"Invalid TlsAction");
	}

	if (rc <= 0) {
		int err = SSL_get_error(m_SSL.get(), rc);

		switch (err) {
			/* Non-blocking socket: retry the same action once the socket is ready. */
			case SSL_ERROR_WANT_READ:
				m_Retry = true;
				ChangeEvents(POLLIN);

				break;
			case SSL_ERROR_WANT_WRITE:
				m_Retry = true;
				ChangeEvents(POLLOUT);

				break;
			/* The peer closed the TLS connection cleanly. */
			case SSL_ERROR_ZERO_RETURN:
				lock.unlock();

				Close();

				return;
			/* Anything else is fatal: record it for HandleError() and close. */
			default:
				m_ErrorCode = ERR_peek_error();
				m_ErrorOccurred = true;

				if (m_ErrorCode != 0) {
					Log(LogWarning, "TlsStream")
						<< "OpenSSL error: " << ERR_error_string(m_ErrorCode, nullptr);
				} else {
					Log(LogWarning, "TlsStream", "TLS stream was disconnected.");
				}

				lock.unlock();

				Close();

				return;
		}
	}

	if (success) {
		m_CurrentAction = TlsActionNone;

		/* Keep polling for output only while data remains queued. */
		if (!m_Eof) {
			if (m_SendQ->GetAvailableBytes() > 0)
				ChangeEvents(POLLIN|POLLOUT);
			else
				ChangeEvents(POLLIN);
		}

		/* Signal handlers run without m_Mutex held. */
		lock.unlock();

		while (m_RecvQ->IsDataAvailable() && IsHandlingEvents())
			SignalDataAvailable();
	}

	/* Graceful shutdown requested via Shutdown(): close once the send queue is drained.
	 * The lock was already released above when success is true. */
	if (m_Shutdown && !m_SendQ->IsDataAvailable()) {
		if (!success)
			lock.unlock();

		Close();
	}
}
305
306 void TlsStream::HandleError() const
307 {
308         if (m_ErrorOccurred) {
309                 BOOST_THROW_EXCEPTION(openssl_error()
310                         << boost::errinfo_api_function("TlsStream::OnEvent")
311                         << errinfo_openssl_error(m_ErrorCode));
312         }
313 }
314
315 void TlsStream::Handshake()
316 {
317         boost::mutex::scoped_lock lock(m_Mutex);
318
319         m_CurrentAction = TlsActionHandshake;
320         ChangeEvents(POLLOUT);
321
322         boost::system_time const timeout = boost::get_system_time() + boost::posix_time::milliseconds(long(Configuration::TlsHandshakeTimeout * 1000));
323
324         while (!m_HandshakeOK && !m_ErrorOccurred && !m_Eof && timeout > boost::get_system_time())
325                 m_CV.timed_wait(lock, timeout);
326
327         if (timeout < boost::get_system_time())
328                 BOOST_THROW_EXCEPTION(std::runtime_error("Timeout was reached (" + Convert::ToString(Configuration::TlsHandshakeTimeout) + ") during TLS handshake."));
329
330         if (m_Eof)
331                 BOOST_THROW_EXCEPTION(std::runtime_error("Socket was closed during TLS handshake."));
332
333         HandleError();
334 }
335
336 /**
337  * Processes data for the stream.
338  */
339 size_t TlsStream::Peek(void *buffer, size_t count, bool allow_partial)
340 {
341         boost::mutex::scoped_lock lock(m_Mutex);
342
343         if (!allow_partial)
344                 while (m_RecvQ->GetAvailableBytes() < count && !m_ErrorOccurred && !m_Eof)
345                         m_CV.wait(lock);
346
347         HandleError();
348
349         return m_RecvQ->Peek(buffer, count, true);
350 }
351
352 size_t TlsStream::Read(void *buffer, size_t count, bool allow_partial)
353 {
354         boost::mutex::scoped_lock lock(m_Mutex);
355
356         if (!allow_partial)
357                 while (m_RecvQ->GetAvailableBytes() < count && !m_ErrorOccurred && !m_Eof)
358                         m_CV.wait(lock);
359
360         HandleError();
361
362         return m_RecvQ->Read(buffer, count, true);
363 }
364
365 void TlsStream::Write(const void *buffer, size_t count)
366 {
367         boost::mutex::scoped_lock lock(m_Mutex);
368
369         m_SendQ->Write(buffer, count);
370
371         ChangeEvents(POLLIN|POLLOUT);
372 }
373
374 void TlsStream::Shutdown()
375 {
376         m_Shutdown = true;
377         ChangeEvents(POLLOUT);
378 }
379
380 /**
381  * Closes the stream.
382  */
383 void TlsStream::Close()
384 {
385         CloseInternal(false);
386 }
387
/**
 * Tears down the stream; only the first call has any effect.
 *
 * Marks the stream as EOF, optionally notifies listeners, unregisters from
 * the event loop, then, under m_Mutex, shuts down the TLS session, releases
 * it, closes the socket and wakes any thread blocked on m_CV.
 *
 * @param inDestructor true when called from ~TlsStream(); suppresses the
 *        SignalDataAvailable() notification.
 */
void TlsStream::CloseInternal(bool inDestructor)
{
	/* NOTE(review): m_Eof is tested and set without holding m_Mutex; confirm
	 * that two threads cannot close the same stream concurrently. */
	if (m_Eof)
		return;

	m_Eof = true;

	if (!inDestructor)
		SignalDataAvailable();

	SocketEvents::Unregister();

	Stream::Close();

	boost::mutex::scoped_lock lock(m_Mutex);

	/* NOTE(review): this early return also skips closing m_Socket and
	 * notifying m_CV; m_SSL is only reset further below, so this appears
	 * reachable only if the constructor left it null. */
	if (!m_SSL)
		return;

	/* https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
	 *
	 * It is recommended to do a bidirectional shutdown by checking
	 * the return value of SSL_shutdown() and call it again until
	 * it returns 1 or a fatal error. A maximum of 2x pending + 2x data
	 * is recommended.
	 */
	int rc = 0;

	/* Stops on any non-zero result (1 = done, negative = error or would block);
	 * rc is not inspected afterwards. */
	for (int i = 0; i < 4; i++) {
		if ((rc = SSL_shutdown(m_SSL.get())))
			break;
	}

	m_SSL.reset();

	m_Socket->Close();
	m_Socket.reset();

	/* Wake threads blocked in Handshake(), Peek() or Read(). */
	m_CV.notify_all();
}
428
429 bool TlsStream::IsEof() const
430 {
431         return m_Eof && m_RecvQ->GetAvailableBytes() < 1u;
432 }
433
434 bool TlsStream::SupportsWaiting() const
435 {
436         return true;
437 }
438
439 bool TlsStream::IsDataAvailable() const
440 {
441         boost::mutex::scoped_lock lock(m_Mutex);
442
443         return m_RecvQ->GetAvailableBytes() > 0;
444 }
445
446 Socket::Ptr TlsStream::GetSocket() const
447 {
448         return m_Socket;
449 }