From caf08b30693f57b2c93e47b32291ee23efed13e0 Mon Sep 17 00:00:00 2001 From: Gunnar Beutner Date: Mon, 6 Aug 2012 10:01:21 +0200 Subject: [PATCH] Improved TcpClient lock usage. --- base/socket.cpp | 17 ++++++----------- base/socket.h | 5 ++--- base/tcpclient.cpp | 44 +++++++++++++++++++++++++++++++------------- base/tcpclient.h | 1 + base/tlsclient.cpp | 30 ++++++++++++++++++++---------- 5 files changed, 60 insertions(+), 37 deletions(-) diff --git a/base/socket.cpp b/base/socket.cpp index c1046ab84..f1a274ccf 100644 --- a/base/socket.cpp +++ b/base/socket.cpp @@ -33,7 +33,7 @@ Socket::Socket(void) */ Socket::~Socket(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); CloseInternal(true); } @@ -90,7 +90,7 @@ SOCKET Socket::GetFD(void) const */ void Socket::Close(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); CloseInternal(false); } @@ -209,7 +209,7 @@ String Socket::GetAddressFromSockaddr(sockaddr *address, socklen_t len) */ String Socket::GetClientAddress(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); sockaddr_storage sin; socklen_t len = sizeof(sin); @@ -227,7 +227,7 @@ String Socket::GetClientAddress(void) */ String Socket::GetPeerAddress(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); sockaddr_storage sin; socklen_t len = sizeof(sin); @@ -258,7 +258,7 @@ SocketException::SocketException(const String& message, int errorCode) void Socket::ReadThreadProc(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); for (;;) { fd_set readfds, exceptfds; @@ -312,7 +312,7 @@ void Socket::ReadThreadProc(void) void Socket::WriteThreadProc(void) { - mutex::scoped_lock lock(m_Mutex); + mutex::scoped_lock lock(m_SocketMutex); for (;;) { fd_set writefds; @@ -358,11 +358,6 @@ void Socket::WriteThreadProc(void) } } -mutex& Socket::GetMutex(void) const -{ - return m_Mutex; -} - void Socket::SetConnected(bool connected) { m_Connected = connected; diff --git a/base/socket.h b/base/socket.h index 13c3af0bf..fbf394004 100644 --- a/base/socket.h +++ b/base/socket.h @@ -44,8 +44,6 @@ public: String GetClientAddress(void); String GetPeerAddress(void); - mutex& GetMutex(void) const; - bool IsConnected(void) const; void CheckException(void); @@ -70,6 +68,8 @@ protected: virtual void CloseInternal(bool from_dtor); + mutable mutex m_SocketMutex; + private: SOCKET m_FD; /**< The socket descriptor. */ bool m_Connected; @@ -79,7 +79,6 @@ private: condition_variable m_WriteCV; - mutable mutex m_Mutex; boost::exception_ptr m_Exception; void ReadThreadProc(void); diff --git a/base/tcpclient.cpp b/base/tcpclient.cpp index 4ea083829..9e980406c 100644 --- a/base/tcpclient.cpp +++ b/base/tcpclient.cpp @@ -113,22 +113,29 @@ void TcpClient::HandleWritable(void) } for (;;) { - count = m_SendQueue->GetAvailableBytes(); + { + mutex::scoped_lock lock(m_QueueMutex); - if (count == 0) - break; + count = m_SendQueue->GetAvailableBytes(); - if (count > sizeof(data)) - count = sizeof(data); + if (count == 0) + break; - m_SendQueue->Peek(data, count); + if (count > sizeof(data)) + count = sizeof(data); + + m_SendQueue->Peek(data, count); + } rc = send(GetFD(), (const char *)data, count, 0); if (rc <= 0) throw_exception(SocketException("send() failed", GetError())); - m_SendQueue->Read(NULL, rc); + { + mutex::scoped_lock lock(m_QueueMutex); + m_SendQueue->Read(NULL, rc); + } } } @@ -137,7 +144,7 @@ void TcpClient::HandleWritable(void) */ size_t TcpClient::GetAvailableBytes(void) const { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_QueueMutex); return m_RecvQueue->GetAvailableBytes(); } @@ -147,7 +154,7 @@ size_t TcpClient::GetAvailableBytes(void) const */ void TcpClient::Peek(void *buffer, size_t count) { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_QueueMutex); m_RecvQueue->Peek(buffer, count); } @@ -157,7 +164,7 @@ void TcpClient::Peek(void *buffer, size_t count) */ void TcpClient::Read(void *buffer, size_t count) { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_QueueMutex); m_RecvQueue->Read(buffer, count); } @@ -167,7 +174,7 @@ void TcpClient::Read(void *buffer, size_t count) */ void TcpClient::Write(const void *buffer, size_t count) { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_QueueMutex); m_SendQueue->Write(buffer, count); } @@ -193,7 +200,11 @@ void TcpClient::HandleReadable(void) if (rc <= 0) throw_exception(SocketException("recv() failed", GetError())); - m_RecvQueue->Write(data, rc); + { + mutex::scoped_lock lock(m_QueueMutex); + + m_RecvQueue->Write(data, rc); + } } Event::Post(boost::bind(boost::ref(OnDataAvailable), GetSelf())); @@ -216,7 +227,14 @@ bool TcpClient::WantsToRead(void) const */ bool TcpClient::WantsToWrite(void) const { - return (m_SendQueue->GetAvailableBytes() > 0 || !IsConnected()); + { + mutex::scoped_lock lock(m_QueueMutex); + + if (m_SendQueue->GetAvailableBytes() > 0) + return true; + } + + return (!IsConnected()); } /** diff --git a/base/tcpclient.h b/base/tcpclient.h index 3498f5d7c..76076e494 100644 --- a/base/tcpclient.h +++ b/base/tcpclient.h @@ -68,6 +68,7 @@ protected: virtual void HandleReadable(void); virtual void HandleWritable(void); + mutable mutex m_QueueMutex; FIFO::Ptr m_SendQueue; FIFO::Ptr m_RecvQueue; diff --git a/base/tlsclient.cpp b/base/tlsclient.cpp index ac2fb522e..111bc2850 100644 --- a/base/tlsclient.cpp +++ b/base/tlsclient.cpp @@ -90,7 +90,7 @@ void TlsClient::NullCertificateDeleter(X509 *certificate) */ shared_ptr TlsClient::GetClientCertificate(void) const { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_SocketMutex); return shared_ptr(SSL_get_certificate(m_SSL.get()), &TlsClient::NullCertificateDeleter); } @@ -102,7 +102,7 @@ shared_ptr TlsClient::GetClientCertificate(void) const */ shared_ptr TlsClient::GetPeerCertificate(void) const { - mutex::scoped_lock lock(GetMutex()); + mutex::scoped_lock lock(m_SocketMutex); return shared_ptr(SSL_get_peer_certificate(m_SSL.get()), X509_free); } @@ -146,8 +146,11 @@ void TlsClient::HandleReadable(void) } } - if (IsConnected()) + if (IsConnected()) { + mutex::scoped_lock lock(m_QueueMutex); + m_RecvQueue->Write(data, rc); + } } post_event: @@ -169,15 +172,19 @@ void TlsClient::HandleWritable(void) int rc; if (IsConnected()) { - count = m_SendQueue->GetAvailableBytes(); + { + mutex::scoped_lock lock(m_QueueMutex); + + count = m_SendQueue->GetAvailableBytes(); - if (count == 0) - break; + if (count == 0) + break; - if (count > sizeof(data)) - count = sizeof(data); + if (count > sizeof(data)) + count = sizeof(data); - m_SendQueue->Peek(data, count); + m_SendQueue->Peek(data, count); + } rc = SSL_write(m_SSL.get(), (const char *)data, count); } else { @@ -205,8 +212,11 @@ void TlsClient::HandleWritable(void) } } - if (IsConnected()) + if (IsConnected()) { + mutex::scoped_lock lock(m_QueueMutex); + m_SendQueue->Read(NULL, rc); + } } } -- 2.40.0