]> granicus.if.org Git - icinga2/commitdiff
Add non-async overloads for NetString::ReadStringFromStream() and NetString::WriteStr...
authorAlexander A. Klimov <alexander.klimov@icinga.com>
Mon, 25 Feb 2019 17:12:32 +0000 (18:12 +0100)
committerAlexander A. Klimov <alexander.klimov@icinga.com>
Mon, 1 Apr 2019 15:11:10 +0000 (17:11 +0200)
lib/base/netstring.cpp
lib/base/netstring.hpp

index 489a8b40db727a00db9dee1652be5b2d6151d598..2be7675a7247f4fe5873db6eef42de8b5807be0a 100644 (file)
@@ -118,6 +118,85 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s
        return msg.GetLength();
 }
 
+/**
+ * Reads data from a stream in netstring format.
+ *
+ * @param stream The stream to read from.
+ * @returns The String that has been read from the IOQueue.
+ * @exception invalid_argument The input stream is invalid.
+ * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
+ */
+String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
+       ssize_t maxMessageLength)
+{
+       namespace asio = boost::asio;
+
+       size_t len = 0;
+       bool leadingZero = false;
+
+       for (uint_fast8_t readBytes = 0;; ++readBytes) {
+               char byte = 0;
+
+               {
+                       asio::mutable_buffer byteBuf (&byte, 1);
+                       asio::read(*stream, byteBuf);
+               }
+
+               if (isdigit(byte)) {
+                       if (readBytes == 9) {
+                               BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
+                       }
+
+                       if (leadingZero) {
+                               BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
+                       }
+
+                       len = len * 10u + size_t(byte - '0');
+
+                       if (!readBytes && byte == '0') {
+                               leadingZero = true;
+                       }
+               } else if (byte == ':') {
+                       if (!readBytes) {
+                               BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
+                       }
+
+                       break;
+               } else {
+                       BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
+               }
+       }
+
+       if (maxMessageLength >= 0 && len > maxMessageLength) {
+               std::stringstream errorMessage;
+               errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
+
+               BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
+       }
+
+       String payload;
+
+       if (len) {
+               payload.Append(len, 0);
+
+               asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
+               asio::read(*stream, payloadBuf);
+       }
+
+       char trailer = 0;
+
+       {
+               asio::mutable_buffer trailerBuf (&trailer, 1);
+               asio::read(*stream, trailerBuf);
+       }
+
+       if (trailer != ',') {
+               BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
+       }
+
+       return std::move(payload);
+}
+
 /**
  * Reads data from a stream in netstring format.
  *
@@ -197,6 +276,29 @@ String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& str
        return std::move(payload);
 }
 
+/**
+ * Writes data into a stream using the netstring format and returns bytes written.
+ *
+ * @param stream The stream.
+ * @param str The String that is to be written.
+ *
+ * @return The amount of bytes written.
+ */
+size_t NetString::WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& str)
+{
+       namespace asio = boost::asio;
+
+       std::ostringstream msgbuf;
+       WriteStringToStream(msgbuf, str);
+
+       String msg = msgbuf.str();
+       asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
+
+       asio::write(*stream, msgBuf);
+
+       return msg.GetLength();
+}
+
 /**
  * Writes data into a stream using the netstring format and returns bytes written.
  *
index f84eac7a3138b6ac316f03a491258b17faf7d0f5..2d24359075ad83bd53831d423ea12b066c7ca0af 100644 (file)
@@ -26,9 +26,11 @@ class NetString
 public:
        static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context,
                bool may_wait = false, ssize_t maxMessageLength = -1);
+       static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
        static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
                boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
        static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message);
+       static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message);
        static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message, boost::asio::yield_context yc);
        static void WriteStringToStream(std::ostream& stream, const String& message);