From aa1ccd7ada4a492ddd6c90d882730b750963ebc3 Mon Sep 17 00:00:00 2001 From: Jean Flach Date: Thu, 8 Feb 2018 14:54:52 +0100 Subject: [PATCH] Authenticate API user before parsing body --- lib/remote/apiuser.cpp | 28 ++++ lib/remote/apiuser.hpp | 1 + lib/remote/httprequest.cpp | 56 ++++---- lib/remote/httprequest.hpp | 1 + lib/remote/httpserverconnection.cpp | 200 +++++++++++++--------------- lib/remote/httpserverconnection.hpp | 6 +- 6 files changed, 157 insertions(+), 135 deletions(-) diff --git a/lib/remote/apiuser.cpp b/lib/remote/apiuser.cpp index fd4aec9a5..fedcd7163 100644 --- a/lib/remote/apiuser.cpp +++ b/lib/remote/apiuser.cpp @@ -20,6 +20,7 @@ #include "remote/apiuser.hpp" #include "remote/apiuser.tcpp" #include "base/configtype.hpp" +#include "base/base64.hpp" using namespace icinga; @@ -34,3 +35,30 @@ ApiUser::Ptr ApiUser::GetByClientCN(const String& cn) return ApiUser::Ptr(); } + +ApiUser::Ptr ApiUser::GetByAuthHeader(const String& auth_header) { + String::SizeType pos = auth_header.FindFirstOf(" "); + String username, password; + + if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") { + String credentials_base64 = auth_header.SubStr(pos + 1); + String credentials = Base64::Decode(credentials_base64); + + String::SizeType cpos = credentials.FindFirstOf(":"); + + if (cpos != String::NPos) { + username = credentials.SubStr(0, cpos); + password = credentials.SubStr(cpos + 1); + } + } + + const ApiUser::Ptr& user = ApiUser::GetByName(username); + + /* Deny authentication if 1) given password is empty 2) configured password does not match. */ + if (password.IsEmpty()) + return nullptr; + else if (user && user->GetPassword() != password) + return nullptr; + + return user; +} diff --git a/lib/remote/apiuser.hpp b/lib/remote/apiuser.hpp index 34a994158..4021487b4 100644 --- a/lib/remote/apiuser.hpp +++ b/lib/remote/apiuser.hpp @@ -36,6 +36,7 @@ public: DECLARE_OBJECTNAME(ApiUser); static ApiUser::Ptr GetByClientCN(const String& cn); + static ApiUser::Ptr GetByAuthHeader(const String& auth_header); }; } diff --git a/lib/remote/httprequest.cpp b/lib/remote/httprequest.cpp index 2c731c4fc..a55e12638 100644 --- a/lib/remote/httprequest.cpp +++ b/lib/remote/httprequest.cpp @@ -30,6 +30,7 @@ using namespace icinga; HttpRequest::HttpRequest(const Stream::Ptr& stream) : CompleteHeaders(false), + CompleteHeaderCheck(false), CompleteBody(false), ProtocolVersion(HttpVersion11), Headers(new Dictionary()), @@ -43,7 +44,7 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait) return false; if (m_State != HttpRequestStart && m_State != HttpRequestHeaders) - return false; + BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state")); String line; StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait); @@ -109,19 +110,19 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait) bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait) { - if (!m_Stream || m_State != HttpRequestBody) + if (!m_Stream) return false; + if (m_State != HttpRequestBody) + BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state")); + /* we're done if the request doesn't contain a message body */ if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) { CompleteBody = true; - return true; + return false; } else if (!m_Body) m_Body = new FIFO(); - if (CompleteBody) - return true; - if (Headers->Get("transfer-encoding") == "chunked") { if (!m_ChunkContext) m_ChunkContext = boost::make_shared(boost::ref(src)); @@ -139,39 +140,38 @@ bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait) if (size == 0) { CompleteBody = true; + return false; + } else return true; - } - } else { - if (src.Eof) - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); + } - if (src.MustRead) { - if (!src.FillFromStream(m_Stream, false)) { - src.Eof = true; - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - } + if (src.Eof) + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - src.MustRead = false; + if (src.MustRead) { + if (!src.FillFromStream(m_Stream, false)) { + src.Eof = true; + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); } - long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); + src.MustRead = false; + } - if (length_indicator_signed < 0) - BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); + long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); - size_t length_indicator = length_indicator_signed; + if (length_indicator_signed < 0) + BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); - if (src.Size < length_indicator) { - src.MustRead = true; - return false; - } + size_t length_indicator = length_indicator_signed; - m_Body->Write(src.Buffer, length_indicator); - src.DropData(length_indicator); - CompleteBody = true; - return true; + if (src.Size < length_indicator) { + src.MustRead = true; + return false; } + m_Body->Write(src.Buffer, length_indicator); + src.DropData(length_indicator); + CompleteBody = true; return true; } diff --git a/lib/remote/httprequest.hpp b/lib/remote/httprequest.hpp index f58cfbbb1..af7ebfbaa 100644 --- a/lib/remote/httprequest.hpp +++ b/lib/remote/httprequest.hpp @@ -53,6 +53,7 @@ struct I2_REMOTE_API HttpRequest { public: bool CompleteHeaders; + bool CompleteHeaderCheck; bool CompleteBody; String RequestMethod; diff --git a/lib/remote/httpserverconnection.cpp b/lib/remote/httpserverconnection.cpp index 3fbee20ba..955862756 100644 --- a/lib/remote/httpserverconnection.cpp +++ b/lib/remote/httpserverconnection.cpp @@ -90,100 +90,105 @@ void HttpServerConnection::Disconnect(void) bool HttpServerConnection::ProcessMessage(void) { bool res; + HttpResponse response(m_Stream, m_CurrentRequest); - try { - res = m_CurrentRequest.ParseHeader(m_Context, false); - } catch (const std::invalid_argument& ex) { - HttpResponse response(m_Stream, m_CurrentRequest); - response.SetStatus(400, "Bad request"); - String msg = String("

Bad request

") + ex.what() + "

"; - response.WriteBody(msg.CStr(), msg.GetLength()); - response.Finish(); + if (!m_CurrentRequest.CompleteHeaders) { + try { + res = m_CurrentRequest.ParseHeader(m_Context, false); + } catch (const std::invalid_argument& ex) { + response.SetStatus(400, "Bad Request"); + String msg = String("

Bad Request

") + ex.what() + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - m_Stream->Shutdown(); - return false; - } catch (const std::exception& ex) { - HttpResponse response(m_Stream, m_CurrentRequest); - response.SetStatus(400, "Bad request"); - String msg = "

Bad request

" + DiagnosticInformation(ex) + "

"; - response.WriteBody(msg.CStr(), msg.GetLength()); - response.Finish(); + m_Stream->Shutdown(); + return false; + } catch (const std::exception& ex) { + response.SetStatus(500, "Internal Server Error"); + String msg = "

Internal Server Error

" + DiagnosticInformation(ex) + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - m_Stream->Shutdown(); - return false; + m_Stream->Shutdown(); + return false; + } + return res; } - if (m_CurrentRequest.CompleteHeaders) { - m_RequestQueue.Enqueue(boost::bind(&HttpServerConnection::ProcessMessageAsync, - HttpServerConnection::Ptr(this), m_CurrentRequest)); + if (!m_CurrentRequest.CompleteHeaderCheck) { + m_CurrentRequest.CompleteHeaderCheck = true; + if (!ManageHeaders(response)) { + m_Stream->Shutdown(); + return false; + } + } - m_Seen = Utility::GetTime(); - m_PendingRequests++; + if (!m_CurrentRequest.CompleteBody) { + try { + res = m_CurrentRequest.ParseBody(m_Context, false); + } catch (const std::invalid_argument& ex) { + response.SetStatus(400, "Bad Request"); + String msg = String("

Bad Request

") + ex.what() + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - m_CurrentRequest.~HttpRequest(); - new (&m_CurrentRequest) HttpRequest(m_Stream); + m_Stream->Shutdown(); + return false; + } catch (const std::exception& ex) { + response.SetStatus(500, "Internal Server Error"); + String msg = "

Internal Server Error

" + DiagnosticInformation(ex) + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - return true; + m_Stream->Shutdown(); + return false; + } + return res; } - return res; -} + m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync, + HttpServerConnection::Ptr(this), m_CurrentRequest, response, m_AuthenticatedUser)); -void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) -{ - String auth_header = request.Headers->Get("authorization"); + m_Seen = Utility::GetTime(); + m_PendingRequests++; - String::SizeType pos = auth_header.FindFirstOf(" "); - String username, password; - - if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") { - String credentials_base64 = auth_header.SubStr(pos + 1); - String credentials = Base64::Decode(credentials_base64); + m_CurrentRequest.~HttpRequest(); + new (&m_CurrentRequest) HttpRequest(m_Stream); - String::SizeType cpos = credentials.FindFirstOf(":"); + return false; +} - if (cpos != String::NPos) { - username = credentials.SubStr(0, cpos); - password = credentials.SubStr(cpos + 1); - } +bool HttpServerConnection::ManageHeaders(HttpResponse& response) +{ + if (m_CurrentRequest.Headers->Get("expect") == "100-continue") { + String continueResponse = "HTTP/1.1 100 Continue\r\n\r\n"; + m_Stream->Write(continueResponse.CStr(), continueResponse.GetLength()); } - ApiUser::Ptr user; - /* client_cn matched. */ if (m_ApiUser) - user = m_ApiUser; - else { - user = ApiUser::GetByName(username); - - /* Deny authentication if 1) given password is empty 2) configured password does not match. */ - if (password.IsEmpty()) - user.reset(); - else if (user && user->GetPassword() != password) - user.reset(); - } + m_AuthenticatedUser = m_ApiUser; + else + m_AuthenticatedUser = ApiUser::GetByAuthHeader(m_CurrentRequest.Headers->Get("authorization")); - String requestUrl = request.RequestUrl->Format(); + String requestUrl = m_CurrentRequest.RequestUrl->Format(); Socket::Ptr socket = m_Stream->GetSocket(); Log(LogInformation, "HttpServerConnection") - << "Request: " << request.RequestMethod << " " << requestUrl + << "Request: " << m_CurrentRequest.RequestMethod << " " << requestUrl << " (from " << (socket ? socket->GetPeerAddress() : "") - << ", user: " << (user ? user->GetName() : "") << ")"; - - HttpResponse response(m_Stream, request); + << ", user: " << (m_AuthenticatedUser ? m_AuthenticatedUser->GetName() : "") << ")"; ApiListener::Ptr listener = ApiListener::GetInstance(); if (!listener) - return; + return false; Array::Ptr headerAllowOrigin = listener->GetAccessControlAllowOrigin(); if (headerAllowOrigin->GetLength() != 0) { - String origin = request.Headers->Get("origin"); - + String origin = m_CurrentRequest.Headers->Get("origin"); { ObjectLock olock(headerAllowOrigin); @@ -196,9 +201,9 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) if (listener->GetAccessControlAllowCredentials()) response.AddHeader("Access-Control-Allow-Credentials", "true"); - String accessControlRequestMethodHeader = request.Headers->Get("access-control-request-method"); + String accessControlRequestMethodHeader = m_CurrentRequest.Headers->Get("access-control-request-method"); - if (!accessControlRequestMethodHeader.IsEmpty()) { + if (m_CurrentRequest.RequestMethod == "OPTIONS" && !accessControlRequestMethodHeader.IsEmpty()) { response.SetStatus(200, "OK"); response.AddHeader("Access-Control-Allow-Methods", listener->GetAccessControlAllowMethods()); @@ -208,27 +213,27 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) response.WriteBody(msg.CStr(), msg.GetLength()); response.Finish(); - m_PendingRequests--; - - return; + return false; } } - String accept_header = request.Headers->Get("accept"); - - if (request.RequestMethod != "GET" && accept_header != "application/json") { + if (m_CurrentRequest.RequestMethod != "GET" && m_CurrentRequest.Headers->Get("accept") != "application/json") { response.SetStatus(400, "Wrong Accept header"); response.AddHeader("Content-Type", "text/html"); String msg = "

Accept header is missing or not set to 'application/json'.

"; response.WriteBody(msg.CStr(), msg.GetLength()); - } else if (!user) { + response.Finish(); + return false; + } + + if (!m_AuthenticatedUser) { Log(LogWarning, "HttpServerConnection") - << "Unauthorized request: " << request.RequestMethod << " " << requestUrl; + << "Unauthorized request: " << m_CurrentRequest.RequestMethod << " " << requestUrl; response.SetStatus(401, "Unauthorized"); response.AddHeader("WWW-Authenticate", "Basic realm=\"Icinga 2\""); - if (request.Headers->Get("accept") == "application/json") { + if (m_CurrentRequest.Headers->Get("accept") == "application/json") { Dictionary::Ptr result = new Dictionary(); result->Set("error", 401); @@ -240,42 +245,25 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) String msg = "

Unauthorized. Please check your user credentials.

"; response.WriteBody(msg.CStr(), msg.GetLength()); } - } else { - bool res = true; - while (!request.CompleteBody) - res = request.ParseBody(m_Context, false); - if (!res) { - Log(LogCritical, "HttpServerConnection", "Failed to read body"); - Dictionary::Ptr result = new Dictionary; - result->Set("error", 404); - result->Set("status", "Bad Request: Malformed body."); - HttpUtility::SendJsonBody(response, result); - } else { - try { - HttpHandler::ProcessRequest(user, request, response); - } catch (const std::exception& ex) { - Log(LogCritical, "HttpServerConnection") - << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); - response.SetStatus(503, "Unhandled exception"); - - String errorInfo = DiagnosticInformation(ex); - - if (request.Headers->Get("accept") == "application/json") { - Dictionary::Ptr result = new Dictionary(); - result->Set("error", 503); - result->Set("status", errorInfo); - - HttpUtility::SendJsonBody(response, result); - } else { - response.AddHeader("Content-Type", "text/plain"); - response.WriteBody(errorInfo.CStr(), errorInfo.GetLength()); - } - } - } + + response.Finish(); + return false; } - response.Finish(); + return true; +} +void HttpServerConnection::ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr& user) +{ + try { + HttpHandler::ProcessRequest(user, request, response); + } catch (const std::exception& ex) { + Log(LogCritical, "HttpServerConnection") + << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); + HttpUtility::SendJsonError(response, 503, "Unhandled exception" , DiagnosticInformation(ex)); + } + + response.Finish(); m_PendingRequests--; } diff --git a/lib/remote/httpserverconnection.hpp b/lib/remote/httpserverconnection.hpp index e8f17dc36..22b036c2f 100644 --- a/lib/remote/httpserverconnection.hpp +++ b/lib/remote/httpserverconnection.hpp @@ -21,6 +21,7 @@ #define HTTPSERVERCONNECTION_H #include "remote/httprequest.hpp" +#include "remote/httpresponse.hpp" #include "remote/apiuser.hpp" #include "base/tlsstream.hpp" #include "base/timer.hpp" @@ -51,6 +52,7 @@ public: private: ApiUser::Ptr m_ApiUser; + ApiUser::Ptr m_AuthenticatedUser; TlsStream::Ptr m_Stream; double m_Seen; HttpRequest m_CurrentRequest; @@ -67,7 +69,9 @@ private: static void TimeoutTimerHandler(void); void CheckLiveness(void); - void ProcessMessageAsync(HttpRequest& request); + bool ManageHeaders(HttpResponse& response); + + void ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr&); }; } -- 2.40.0