]> granicus.if.org Git - icinga2/commitdiff
Bugfixes for SSL sockets.
authorGunnar Beutner <gunnar.beutner@netways.de>
Mon, 16 Jul 2012 09:44:11 +0000 (11:44 +0200)
committerGunnar Beutner <gunnar.beutner@netways.de>
Mon, 16 Jul 2012 10:25:09 +0000 (12:25 +0200)
base/socket.cpp
base/socket.h
base/tcpclient.cpp
base/tlsclient.cpp

index 810ce314bced3144d30e495b3dd90206df4a510c..7ca0f949b3a4406e518b0f900e31d13f41077780 100644 (file)
@@ -316,12 +316,8 @@ void Socket::ReadThreadProc(void)
                        return;
                }
 
-               if (FD_ISSET(fd, &readfds)) {
-                       if (!m_Connected)
-                               m_Connected = true;
-
+               if (FD_ISSET(fd, &readfds))
                        HandleReadable();
-               }
 
                if (FD_ISSET(fd, &exceptfds))
                        HandleException();
@@ -340,7 +336,7 @@ void Socket::WriteThreadProc(void)
 
                FD_ZERO(&writefds);
 
-               while (!WantsToWrite() && m_Connected) {
+               while (!WantsToWrite()) {
                        m_WriteCV.timed_wait(lock, boost::posix_time::seconds(1));
 
                        if (GetFD() == INVALID_SOCKET)
@@ -368,12 +364,8 @@ void Socket::WriteThreadProc(void)
                        return;
                }
 
-               if (FD_ISSET(fd, &writefds)) {
-                       if (!m_Connected)
-                               m_Connected = true;
-
+               if (FD_ISSET(fd, &writefds))
                        HandleWritable();
-               }
        }
 }
 
@@ -381,3 +373,13 @@ mutex& Socket::GetMutex(void) const
 {
        return m_Mutex;
 }
+
+void Socket::SetConnected(bool connected)
+{
+       m_Connected = connected;
+}
+
+bool Socket::IsConnected(void) const
+{
+       return m_Connected;
+}
index e046f5a29a9ba60bcc4f20aa0dbcbb878dafa899..5eeb0fe5b647c0792839d022c502078ab2daa227 100644 (file)
@@ -53,6 +53,9 @@ protected:
        void SetFD(SOCKET fd);
        SOCKET GetFD(void) const;
 
+       void SetConnected(bool connected);
+       bool IsConnected(void) const;
+
        int GetError(void) const;
        static int GetLastSocketError(void);
        void HandleSocketError(const exception& ex);
index ce9c89a423dd64f58d683b4ac7787d98357ab892..c797f3bcdd6663214c15e6b1b3c43e431ed03187 100644 (file)
@@ -120,10 +120,14 @@ void TcpClient::HandleWritable(void)
                rc = send(GetFD(), (const char *)data, count, 0);
 
                if (rc <= 0) {
+                       SetConnected(false);
+
                        HandleSocketError(SocketException("send() failed", GetError()));
                        return;
                }
 
+               SetConnected(true);
+
                m_SendQueue->Read(NULL, rc);
        }
 }
@@ -182,10 +186,14 @@ void TcpClient::HandleReadable(void)
                        return;
 
                if (rc <= 0) {
+                       SetConnected(false);
+
                        HandleSocketError(SocketException("recv() failed", GetError()));
                        return;
                }
 
+               SetConnected(true);
+
                m_RecvQueue->Write(data, rc);
        }
 
index 593bfed855b20494f2510aa485937f91d6b53cdb..c948e89031b697593dd53bd32c210ce1c306c28d 100644 (file)
@@ -118,7 +118,13 @@ void TlsClient::HandleReadable(void)
 
        for (;;) {
                char data[1024];
-               int rc = SSL_read(m_SSL.get(), data, sizeof(data));
+               int rc;
+
+               if (IsConnected()) {
+                       rc = SSL_read(m_SSL.get(), data, sizeof(data));
+               } else {
+                       rc = SSL_do_handshake(m_SSL.get());
+               }
 
                if (rc <= 0) {
                        switch (SSL_get_error(m_SSL.get(), rc)) {
@@ -137,7 +143,10 @@ void TlsClient::HandleReadable(void)
                        }
                }
 
-               m_RecvQueue->Write(data, rc);
+               if (IsConnected())
+                       m_RecvQueue->Write(data, rc);
+               else
+                       SetConnected(true);
        }
 
 post_event:
@@ -156,17 +165,23 @@ void TlsClient::HandleWritable(void)
        size_t count;
 
        for (;;) {
-               count = m_SendQueue->GetAvailableBytes();
+               int rc;
 
-               if (count == 0)
-                       break;
+               if (IsConnected()) {
+                       count = m_SendQueue->GetAvailableBytes();
 
-               if (count > sizeof(data))
-                       count = sizeof(data);
+                       if (count == 0)
+                               break;
 
-               m_SendQueue->Peek(data, count);
+                       if (count > sizeof(data))
+                               count = sizeof(data);
 
-               int rc = SSL_write(m_SSL.get(), (const char *)data, count);
+                       m_SendQueue->Peek(data, count);
+
+                       rc = SSL_write(m_SSL.get(), (const char *)data, count);
+               } else {
+                       rc = SSL_do_handshake(m_SSL.get());
+               }
 
                if (rc <= 0) {
                        switch (SSL_get_error(m_SSL.get(), rc)) {
@@ -185,7 +200,10 @@ void TlsClient::HandleWritable(void)
                        }
                }
 
-               m_SendQueue->Read(NULL, rc);
+               if (IsConnected())
+                       m_SendQueue->Read(NULL, rc);
+               else
+                       SetConnected(true);
        }
 }