]> granicus.if.org Git - icinga2/blob - lib/base/netstring.cpp
489a8b40db727a00db9dee1652be5b2d6151d598
[icinga2] / lib / base / netstring.cpp
1 /* Icinga 2 | (c) 2012 Icinga GmbH | GPLv2+ */
2
3 #include "base/netstring.hpp"
4 #include "base/debug.hpp"
5 #include "base/tlsstream.hpp"
6 #include <cstdint>
7 #include <memory>
8 #include <sstream>
9 #include <utility>
10 #include <boost/asio/buffer.hpp>
11 #include <boost/asio/read.hpp>
12 #include <boost/asio/spawn.hpp>
13 #include <boost/asio/write.hpp>
14
15 using namespace icinga;
16
17 /**
18  * Reads data from a stream in netstring format.
19  *
20  * @param stream The stream to read from.
21  * @param[out] str The String that has been read from the IOQueue.
22  * @returns true if a complete String was read from the IOQueue, false otherwise.
23  * @exception invalid_argument The input stream is invalid.
24  * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
25  */
26 StreamReadStatus NetString::ReadStringFromStream(const Stream::Ptr& stream, String *str, StreamReadContext& context,
27         bool may_wait, ssize_t maxMessageLength)
28 {
29         if (context.Eof)
30                 return StatusEof;
31
32         if (context.MustRead) {
33                 if (!context.FillFromStream(stream, may_wait)) {
34                         context.Eof = true;
35                         return StatusEof;
36                 }
37
38                 context.MustRead = false;
39         }
40
41         size_t header_length = 0;
42
43         for (size_t i = 0; i < context.Size; i++) {
44                 if (context.Buffer[i] == ':') {
45                         header_length = i;
46
47                         /* make sure there's a header */
48                         if (header_length == 0)
49                                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
50
51                         break;
52                 } else if (i > 16)
53                         BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
54         }
55
56         if (header_length == 0) {
57                 context.MustRead = true;
58                 return StatusNeedData;
59         }
60
61         /* no leading zeros allowed */
62         if (context.Buffer[0] == '0' && isdigit(context.Buffer[1]))
63                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
64
65         size_t len, i;
66
67         len = 0;
68         for (i = 0; i < header_length && isdigit(context.Buffer[i]); i++) {
69                 /* length specifier must have at most 9 characters */
70                 if (i >= 9)
71                         BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
72
73                 len = len * 10 + (context.Buffer[i] - '0');
74         }
75
76         /* read the whole message */
77         size_t data_length = len + 1;
78
79         if (maxMessageLength >= 0 && data_length > (size_t)maxMessageLength) {
80                 std::stringstream errorMessage;
81                 errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
82
83                 BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
84         }
85
86         char *data = context.Buffer + header_length + 1;
87
88         if (context.Size < header_length + 1 + data_length) {
89                 context.MustRead = true;
90                 return StatusNeedData;
91         }
92
93         if (data[len] != ',')
94                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
95
96         *str = String(&data[0], &data[len]);
97
98         context.DropData(header_length + 1 + len + 1);
99
100         return StatusNewItem;
101 }
102
103 /**
104  * Writes data into a stream using the netstring format and returns bytes written.
105  *
106  * @param stream The stream.
107  * @param str The String that is to be written.
108  *
109  * @return The amount of bytes written.
110  */
111 size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& str)
112 {
113         std::ostringstream msgbuf;
114         WriteStringToStream(msgbuf, str);
115
116         String msg = msgbuf.str();
117         stream->Write(msg.CStr(), msg.GetLength());
118         return msg.GetLength();
119 }
120
121 /**
122  * Reads data from a stream in netstring format.
123  *
124  * @param stream The stream to read from.
125  * @returns The String that has been read from the IOQueue.
126  * @exception invalid_argument The input stream is invalid.
127  * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
128  */
129 String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
130         boost::asio::yield_context yc, ssize_t maxMessageLength)
131 {
132         namespace asio = boost::asio;
133
134         size_t len = 0;
135         bool leadingZero = false;
136
137         for (uint_fast8_t readBytes = 0;; ++readBytes) {
138                 char byte = 0;
139
140                 {
141                         asio::mutable_buffer byteBuf (&byte, 1);
142                         asio::async_read(*stream, byteBuf, yc);
143                 }
144
145                 if (isdigit(byte)) {
146                         if (readBytes == 9) {
147                                 BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
148                         }
149
150                         if (leadingZero) {
151                                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
152                         }
153
154                         len = len * 10u + size_t(byte - '0');
155
156                         if (!readBytes && byte == '0') {
157                                 leadingZero = true;
158                         }
159                 } else if (byte == ':') {
160                         if (!readBytes) {
161                                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
162                         }
163
164                         break;
165                 } else {
166                         BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
167                 }
168         }
169
170         if (maxMessageLength >= 0 && len > maxMessageLength) {
171                 std::stringstream errorMessage;
172                 errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
173
174                 BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
175         }
176
177         String payload;
178
179         if (len) {
180                 payload.Append(len, 0);
181
182                 asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
183                 asio::async_read(*stream, payloadBuf, yc);
184         }
185
186         char trailer = 0;
187
188         {
189                 asio::mutable_buffer trailerBuf (&trailer, 1);
190                 asio::async_read(*stream, trailerBuf, yc);
191         }
192
193         if (trailer != ',') {
194                 BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
195         }
196
197         return std::move(payload);
198 }
199
200 /**
201  * Writes data into a stream using the netstring format and returns bytes written.
202  *
203  * @param stream The stream.
204  * @param str The String that is to be written.
205  *
206  * @return The amount of bytes written.
207  */
208 size_t NetString::WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& str, boost::asio::yield_context yc)
209 {
210         namespace asio = boost::asio;
211
212         std::ostringstream msgbuf;
213         WriteStringToStream(msgbuf, str);
214
215         String msg = msgbuf.str();
216         asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
217
218         asio::async_write(*stream, msgBuf, yc);
219
220         return msg.GetLength();
221 }
222
223 /**
224  * Writes data into a stream using the netstring format.
225  *
226  * @param stream The stream.
227  * @param str The String that is to be written.
228  */
229 void NetString::WriteStringToStream(std::ostream& stream, const String& str)
230 {
231         stream << str.GetLength() << ":" << str << ",";
232 }