]> granicus.if.org Git - icinga2/commitdiff
Authenticate API user before parsing body
authorJean Flach <jean-marcel.flach@icinga.com>
Thu, 8 Feb 2018 13:54:52 +0000 (14:54 +0100)
committerJean Flach <jean-marcel.flach@icinga.com>
Tue, 20 Feb 2018 12:32:04 +0000 (13:32 +0100)
lib/remote/apiuser.cpp
lib/remote/apiuser.hpp
lib/remote/httprequest.cpp
lib/remote/httprequest.hpp
lib/remote/httpserverconnection.cpp
lib/remote/httpserverconnection.hpp

index 183291af314db4ee16fdb8bd4c96e26f82091277..416eedb349847f82f42c2ba8bbd4493bf24a970c 100644 (file)
@@ -20,6 +20,7 @@
 #include "remote/apiuser.hpp"
 #include "remote/apiuser-ti.cpp"
 #include "base/configtype.hpp"
+#include "base/base64.hpp"
 
 using namespace icinga;
 
@@ -34,3 +35,30 @@ ApiUser::Ptr ApiUser::GetByClientCN(const String& cn)
 
        return nullptr;
 }
+
+ApiUser::Ptr ApiUser::GetByAuthHeader(const String& auth_header) {
+       String::SizeType pos = auth_header.FindFirstOf(" ");
+       String username, password;
+
+       if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") {
+               String credentials_base64 = auth_header.SubStr(pos + 1);
+               String credentials = Base64::Decode(credentials_base64);
+
+               String::SizeType cpos = credentials.FindFirstOf(":");
+
+               if (cpos != String::NPos) {
+                       username = credentials.SubStr(0, cpos);
+                       password = credentials.SubStr(cpos + 1);
+               }
+       }
+
+       const ApiUser::Ptr& user = ApiUser::GetByName(username);
+
+       /* Deny authentication if 1) given password is empty 2) configured password does not match. */
+       if (password.IsEmpty())
+               return nullptr;
+       else if (user && user->GetPassword() != password)
+               return nullptr;
+
+       return user;
+}
index 755273bf4cf46f86e8c6b9399e1f0c051847e7db..15b1c41e371b7a9ea28dc5a4cc19867c6d4d872f 100644 (file)
@@ -36,6 +36,7 @@ public:
        DECLARE_OBJECTNAME(ApiUser);
 
        static ApiUser::Ptr GetByClientCN(const String& cn);
+       static ApiUser::Ptr GetByAuthHeader(const String& auth_header);
 };
 
 }
index 0d6c4a6182f0431477e8c8869b9593a612cd602f..c5382d25412484fbb99123f572a6a65fdbfe890c 100644 (file)
@@ -26,6 +26,7 @@ using namespace icinga;
 
 HttpRequest::HttpRequest(Stream::Ptr stream)
        : CompleteHeaders(false),
+       CompleteHeaderCheck(false),
        CompleteBody(false),
        ProtocolVersion(HttpVersion11),
        Headers(new Dictionary()),
@@ -39,7 +40,7 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait)
                return false;
 
        if (m_State != HttpRequestStart && m_State != HttpRequestHeaders)
-               return false;
+               BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state"));
 
        String line;
        StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait);
@@ -105,19 +106,19 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait)
 
 bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait)
 {
-       if (!m_Stream || m_State != HttpRequestBody)
+       if (!m_Stream)
                return false;
 
+       if (m_State != HttpRequestBody)
+               BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state"));
+
        /* we're done if the request doesn't contain a message body */
        if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) {
                CompleteBody = true;
-               return true;
+               return false;
        } else if (!m_Body)
                m_Body = new FIFO();
 
-       if (CompleteBody)
-               return true;
-
        if (Headers->Get("transfer-encoding") == "chunked") {
                if (!m_ChunkContext)
                        m_ChunkContext = std::make_shared<ChunkReadContext>(std::ref(src));
@@ -135,39 +136,38 @@ bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait)
 
                if (size == 0) {
                        CompleteBody = true;
+                       return false;
+               } else
                        return true;
-               }
-       } else {
-               if (src.Eof)
-                       BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
+       }
 
-               if (src.MustRead) {
-                       if (!src.FillFromStream(m_Stream, false)) {
-                               src.Eof = true;
-                               BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
-                       }
+       if (src.Eof)
+               BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
 
-                       src.MustRead = false;
+       if (src.MustRead) {
+               if (!src.FillFromStream(m_Stream, false)) {
+                       src.Eof = true;
+                       BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
                }
 
-               long length_indicator_signed = Convert::ToLong(Headers->Get("content-length"));
+               src.MustRead = false;
+       }
 
-               if (length_indicator_signed < 0)
-                       BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative."));
+       long length_indicator_signed = Convert::ToLong(Headers->Get("content-length"));
 
-               size_t length_indicator = length_indicator_signed;
+       if (length_indicator_signed < 0)
+               BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative."));
 
-               if (src.Size < length_indicator) {
-                       src.MustRead = true;
-                       return false;
-               }
+       size_t length_indicator = length_indicator_signed;
 
-               m_Body->Write(src.Buffer, length_indicator);
-               src.DropData(length_indicator);
-               CompleteBody = true;
-               return true;
+       if (src.Size < length_indicator) {
+               src.MustRead = true;
+               return false;
        }
 
+       m_Body->Write(src.Buffer, length_indicator);
+       src.DropData(length_indicator);
+       CompleteBody = true;
        return true;
 }
 
index e8591474cb2e5681a88a3e7eafa060be9e69a03e..b456d72294414138ff0a25538babdc29ae20008e 100644 (file)
@@ -53,6 +53,7 @@ struct HttpRequest
 {
 public:
        bool CompleteHeaders;
+       bool CompleteHeaderCheck;
        bool CompleteBody;
 
        String RequestMethod;
index 9e400da161e769039d3060129f90041206af6e97..c22f0ee155bfb06869497bea13983cd43f8419d6 100644 (file)
@@ -90,100 +90,105 @@ void HttpServerConnection::Disconnect()
 bool HttpServerConnection::ProcessMessage()
 {
        bool res;
+       HttpResponse response(m_Stream, m_CurrentRequest);
 
-       try {
-               res = m_CurrentRequest.ParseHeader(m_Context, false);
-       } catch (const std::invalid_argument& ex) {
-               HttpResponse response(m_Stream, m_CurrentRequest);
-               response.SetStatus(400, "Bad request");
-               String msg = String("<h1>Bad request</h1><p><pre>") + ex.what() + "</pre></p>";
-               response.WriteBody(msg.CStr(), msg.GetLength());
-               response.Finish();
+       if (!m_CurrentRequest.CompleteHeaders) {
+               try {
+                       res = m_CurrentRequest.ParseHeader(m_Context, false);
+               } catch (const std::invalid_argument& ex) {
+                       response.SetStatus(400, "Bad Request");
+                       String msg = String("<h1>Bad Request</h1><p><pre>") + ex.what() + "</pre></p>";
+                       response.WriteBody(msg.CStr(), msg.GetLength());
+                       response.Finish();
 
-               m_Stream->Shutdown();
-               return false;
-       } catch (const std::exception& ex) {
-               HttpResponse response(m_Stream, m_CurrentRequest);
-               response.SetStatus(400, "Bad request");
-               String msg = "<h1>Bad request</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
-               response.WriteBody(msg.CStr(), msg.GetLength());
-               response.Finish();
+                       m_Stream->Shutdown();
+                       return false;
+               } catch (const std::exception& ex) {
+                       response.SetStatus(500, "Internal Server Error");
+                       String msg = "<h1>Internal Server Error</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
+                       response.WriteBody(msg.CStr(), msg.GetLength());
+                       response.Finish();
 
-               m_Stream->Shutdown();
-               return false;
+                       m_Stream->Shutdown();
+                       return false;
+               }
+               return res;
        }
 
-       if (m_CurrentRequest.CompleteHeaders) {
-               m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync,
-                       HttpServerConnection::Ptr(this), m_CurrentRequest));
+       if (!m_CurrentRequest.CompleteHeaderCheck) {
+               m_CurrentRequest.CompleteHeaderCheck = true;
+               if (!ManageHeaders(response)) {
+                       m_Stream->Shutdown();
+                       return false;
+               }
+       }
 
-               m_Seen = Utility::GetTime();
-               m_PendingRequests++;
+       if (!m_CurrentRequest.CompleteBody) {
+               try {
+                       res = m_CurrentRequest.ParseBody(m_Context, false);
+               } catch (const std::invalid_argument& ex) {
+                       response.SetStatus(400, "Bad Request");
+                       String msg = String("<h1>Bad Request</h1><p><pre>") + ex.what() + "</pre></p>";
+                       response.WriteBody(msg.CStr(), msg.GetLength());
+                       response.Finish();
 
-               m_CurrentRequest.~HttpRequest();
-               new (&m_CurrentRequest) HttpRequest(m_Stream);
+                       m_Stream->Shutdown();
+                       return false;
+               } catch (const std::exception& ex) {
+                       response.SetStatus(500, "Internal Server Error");
+                       String msg = "<h1>Internal Server Error</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
+                       response.WriteBody(msg.CStr(), msg.GetLength());
+                       response.Finish();
 
-               return true;
+                       m_Stream->Shutdown();
+                       return false;
+               }
+               return res;
        }
 
-       return res;
-}
-
-void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
-{
-       String auth_header = request.Headers->Get("authorization");
+       m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync,
+               HttpServerConnection::Ptr(this), m_CurrentRequest, response, m_AuthenticatedUser));
 
-       String::SizeType pos = auth_header.FindFirstOf(" ");
-       String username, password;
+       m_Seen = Utility::GetTime();
+       m_PendingRequests++;
 
-       if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") {
-               String credentials_base64 = auth_header.SubStr(pos + 1);
-               String credentials = Base64::Decode(credentials_base64);
+       m_CurrentRequest.~HttpRequest();
+       new (&m_CurrentRequest) HttpRequest(m_Stream);
 
-               String::SizeType cpos = credentials.FindFirstOf(":");
+       return false;
+}
 
-               if (cpos != String::NPos) {
-                       username = credentials.SubStr(0, cpos);
-                       password = credentials.SubStr(cpos + 1);
-               }
+bool HttpServerConnection::ManageHeaders(HttpResponse& response)
+{
+       if (m_CurrentRequest.Headers->Get("expect") == "100-continue") {
+               String continueResponse = "HTTP/1.1 100 Continue\r\n\r\n";
+               m_Stream->Write(continueResponse.CStr(), continueResponse.GetLength());
        }
 
-       ApiUser::Ptr user;
-
        /* client_cn matched. */
        if (m_ApiUser)
-               user = m_ApiUser;
-       else {
-               user = ApiUser::GetByName(username);
-
-               /* Deny authentication if 1) given password is empty 2) configured password does not match. */
-               if (password.IsEmpty())
-                       user.reset();
-               else if (user && user->GetPassword() != password)
-                       user.reset();
-       }
+               m_AuthenticatedUser = m_ApiUser;
+       else
+               m_AuthenticatedUser = ApiUser::GetByAuthHeader(m_CurrentRequest.Headers->Get("authorization"));
 
-       String requestUrl = request.RequestUrl->Format();
+       String requestUrl = m_CurrentRequest.RequestUrl->Format();
 
        Socket::Ptr socket = m_Stream->GetSocket();
 
        Log(LogInformation, "HttpServerConnection")
-               << "Request: " << request.RequestMethod << " " << requestUrl
+               << "Request: " << m_CurrentRequest.RequestMethod << " " << requestUrl
                << " (from " << (socket ? socket->GetPeerAddress() : "<unkown>")
-               << ", user: " << (user ? user->GetName() : "<unauthenticated>") << ")";
-
-       HttpResponse response(m_Stream, request);
+               << ", user: " << (m_AuthenticatedUser ? m_AuthenticatedUser->GetName() : "<unauthenticated>") << ")";
 
        ApiListener::Ptr listener = ApiListener::GetInstance();
 
        if (!listener)
-               return;
+               return false;
 
        Array::Ptr headerAllowOrigin = listener->GetAccessControlAllowOrigin();
 
        if (headerAllowOrigin->GetLength() != 0) {
-               String origin = request.Headers->Get("origin");
-
+               String origin = m_CurrentRequest.Headers->Get("origin");
                {
                        ObjectLock olock(headerAllowOrigin);
 
@@ -196,9 +201,9 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
                if (listener->GetAccessControlAllowCredentials())
                        response.AddHeader("Access-Control-Allow-Credentials", "true");
 
-               String accessControlRequestMethodHeader = request.Headers->Get("access-control-request-method");
+               String accessControlRequestMethodHeader = m_CurrentRequest.Headers->Get("access-control-request-method");
 
-               if (!accessControlRequestMethodHeader.IsEmpty()) {
+               if (m_CurrentRequest.RequestMethod == "OPTIONS" && !accessControlRequestMethodHeader.IsEmpty()) {
                        response.SetStatus(200, "OK");
 
                        response.AddHeader("Access-Control-Allow-Methods", listener->GetAccessControlAllowMethods());
@@ -208,27 +213,27 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
                        response.WriteBody(msg.CStr(), msg.GetLength());
 
                        response.Finish();
-                       m_PendingRequests--;
-
-                       return;
+                       return false;
                }
        }
 
-       String accept_header = request.Headers->Get("accept");
-
-       if (request.RequestMethod != "GET" && accept_header != "application/json") {
+       if (m_CurrentRequest.RequestMethod != "GET" && m_CurrentRequest.Headers->Get("accept") != "application/json") {
                response.SetStatus(400, "Wrong Accept header");
                response.AddHeader("Content-Type", "text/html");
                String msg = "<h1>Accept header is missing or not set to 'application/json'.</h1>";
                response.WriteBody(msg.CStr(), msg.GetLength());
-       } else if (!user) {
+               response.Finish();
+               return false;
+       }
+
+       if (!m_AuthenticatedUser) {
                Log(LogWarning, "HttpServerConnection")
-                       << "Unauthorized request: " << request.RequestMethod << " " << requestUrl;
+                       << "Unauthorized request: " << m_CurrentRequest.RequestMethod << " " << requestUrl;
 
                response.SetStatus(401, "Unauthorized");
                response.AddHeader("WWW-Authenticate", "Basic realm=\"Icinga 2\"");
 
-               if (request.Headers->Get("accept") == "application/json") {
+               if (m_CurrentRequest.Headers->Get("accept") == "application/json") {
                        Dictionary::Ptr result = new Dictionary({
                                { "error", 401 },
                                { "status", "Unauthorized. Please check your user credentials." }
@@ -240,44 +245,25 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
                        String msg = "<h1>Unauthorized. Please check your user credentials.</h1>";
                        response.WriteBody(msg.CStr(), msg.GetLength());
                }
-       } else {
-               bool res = true;
-               while (!request.CompleteBody)
-                       res = request.ParseBody(m_Context, false);
-               if (!res) {
-                       Log(LogCritical, "HttpServerConnection", "Failed to read body");
-                       Dictionary::Ptr result = new Dictionary({
-                               { "error", 400 },
-                               { "status", "Bad Request: Malformed body." }
-                       });
-                       HttpUtility::SendJsonBody(response, nullptr, result);
-               } else {
-                       try {
-                               HttpHandler::ProcessRequest(user, request, response);
-                       } catch (const std::exception& ex) {
-                               Log(LogCritical, "HttpServerConnection")
-                                       << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex);
-                               response.SetStatus(503, "Unhandled exception");
-
-                               String errorInfo = DiagnosticInformation(ex);
-
-                               if (request.Headers->Get("accept") == "application/json") {
-                                       Dictionary::Ptr result = new Dictionary({
-                                               { "error", 503 },
-                                               { "status", errorInfo }
-                                       });
-
-                                       HttpUtility::SendJsonBody(response, nullptr, result);
-                               } else {
-                                       response.AddHeader("Content-Type", "text/plain");
-                                       response.WriteBody(errorInfo.CStr(), errorInfo.GetLength());
-                               }
-                       }
-               }
+
+               response.Finish();
+               return false;
        }
 
-       response.Finish();
+       return true;
+}
 
+void HttpServerConnection::ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr& user)
+{
+       try {
+               HttpHandler::ProcessRequest(user, request, response);
+       } catch (const std::exception& ex) {
+               Log(LogCritical, "HttpServerConnection")
+                       << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex);
+               HttpUtility::SendJsonError(response, nullptr, 503, "Unhandled exception" , DiagnosticInformation(ex));
+       }
+
+       response.Finish();
        m_PendingRequests--;
 }
 
index f521100136ff3404aafc4ed66fcc33f303843c2e..104df75093e3c4c0a32fb096d29cf672d9614d57 100644 (file)
@@ -21,6 +21,7 @@
 #define HTTPSERVERCONNECTION_H
 
 #include "remote/httprequest.hpp"
+#include "remote/httpresponse.hpp"
 #include "remote/apiuser.hpp"
 #include "base/tlsstream.hpp"
 #include "base/timer.hpp"
@@ -51,6 +52,7 @@ public:
 
 private:
        ApiUser::Ptr m_ApiUser;
+       ApiUser::Ptr m_AuthenticatedUser;
        TlsStream::Ptr m_Stream;
        double m_Seen;
        HttpRequest m_CurrentRequest;
@@ -67,7 +69,9 @@ private:
        static void TimeoutTimerHandler();
        void CheckLiveness();
 
-       void ProcessMessageAsync(HttpRequest& request);
+       bool ManageHeaders(HttpResponse& response);
+
+       void ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr&);
 };
 
 }