]> granicus.if.org Git - pgbouncer/blob - src/sbuf.c
Fix some scan-build warnings
[pgbouncer] / src / sbuf.c
1 /*
2  * PgBouncer - Lightweight connection pooler for PostgreSQL.
3  *
4  * Copyright (c) 2007-2009  Marko Kreen, Skype Technologies OÜ
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18
19 /*
20  * Stream buffer
21  *
22  * The task is to copy data from one socket to another
23  * efficiently, while allowing callbacks to look
24  * at packet headers.
25  */
26
27 #include "bouncer.h"
28
29 #ifdef USUAL_LIBSSL_FOR_TLS
30 #define USE_TLS
31 #endif
32
33 /* sbuf_main_loop() skip_recv values */
34 #define DO_RECV         false
35 #define SKIP_RECV       true
36
37 #define ACT_UNSET 0
38 #define ACT_SEND 1
39 #define ACT_SKIP 2
40 #define ACT_CALL 3
41
42 enum TLSState {
43         SBUF_TLS_NONE,
44         SBUF_TLS_DO_HANDSHAKE,
45         SBUF_TLS_IN_HANDSHAKE,
46         SBUF_TLS_OK,
47 };
48
49 enum WaitType {
50         W_NONE = 0,
51         W_CONNECT,
52         W_RECV,
53         W_SEND,
54         W_ONCE
55 };
56
57 #define AssertSanity(sbuf) do { \
58         Assert(iobuf_sane((sbuf)->io)); \
59 } while (0)
60
61 #define AssertActive(sbuf) do { \
62         Assert((sbuf)->sock > 0); \
63         AssertSanity(sbuf); \
64 } while (0)
65
66 /* declare static stuff */
67 static bool sbuf_queue_send(SBuf *sbuf) _MUSTCHECK;
68 static bool sbuf_send_pending(SBuf *sbuf) _MUSTCHECK;
69 static bool sbuf_process_pending(SBuf *sbuf) _MUSTCHECK;
70 static void sbuf_connect_cb(evutil_socket_t sock, short flags, void *arg);
71 static void sbuf_recv_cb(evutil_socket_t sock, short flags, void *arg);
72 static void sbuf_send_cb(evutil_socket_t sock, short flags, void *arg);
73 static void sbuf_try_resync(SBuf *sbuf, bool release);
74 static bool sbuf_wait_for_data(SBuf *sbuf) _MUSTCHECK;
75 static void sbuf_main_loop(SBuf *sbuf, bool skip_recv);
76 static bool sbuf_call_proto(SBuf *sbuf, int event) /* _MUSTCHECK */;
77 static bool sbuf_actual_recv(SBuf *sbuf, unsigned len)  _MUSTCHECK;
78 static bool sbuf_after_connect_check(SBuf *sbuf)  _MUSTCHECK;
79 static bool handle_tls_handshake(SBuf *sbuf) /* _MUSTCHECK */;
80
81 /* regular I/O */
82 static int raw_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len);
83 static int raw_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len);
84 static int raw_sbufio_close(struct SBuf *sbuf);
85 static const SBufIO raw_sbufio_ops = {
86         raw_sbufio_recv,
87         raw_sbufio_send,
88         raw_sbufio_close
89 };
90
91 /* I/O over TLS */
92 #ifdef USE_TLS
93 static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len);
94 static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len);
95 static int tls_sbufio_close(struct SBuf *sbuf);
96 static const SBufIO tls_sbufio_ops = {
97         tls_sbufio_recv,
98         tls_sbufio_send,
99         tls_sbufio_close
100 };
101 static void sbuf_tls_handshake_cb(evutil_socket_t fd, short flags, void *_sbuf);
102 #endif
103
104 /*********************************
105  * Public functions
106  *********************************/
107
108 /* initialize SBuf with proto handler */
109 void sbuf_init(SBuf *sbuf, sbuf_cb_t proto_fn)
110 {
111         memset(sbuf, 0, sizeof(SBuf));
112         sbuf->proto_cb = proto_fn;
113         sbuf->ops = &raw_sbufio_ops;
114 }
115
116 /* got new socket from accept() */
117 bool sbuf_accept(SBuf *sbuf, int sock, bool is_unix)
118 {
119         bool res;
120
121         Assert(iobuf_empty(sbuf->io) && sbuf->sock == 0);
122         AssertSanity(sbuf);
123
124         sbuf->sock = sock;
125         if (!tune_socket(sock, is_unix))
126                 goto failed;
127
128         if (!cf_reboot) {
129                 res = sbuf_wait_for_data(sbuf);
130                 if (!res)
131                         goto failed;
132                 /* socket should already have some data (linux only) */
133                 if (cf_tcp_defer_accept && !is_unix) {
134                         sbuf_main_loop(sbuf, DO_RECV);
135                         if (!sbuf->sock)
136                                 return false;
137                 }
138         }
139         return true;
140 failed:
141         sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
142         return false;
143 }
144
145 /* need to connect() to get a socket */
146 bool sbuf_connect(SBuf *sbuf, const struct sockaddr *sa, int sa_len, int timeout_sec)
147 {
148         int res, sock;
149         struct timeval timeout;
150         bool is_unix = sa->sa_family == AF_UNIX;
151
152         Assert(iobuf_empty(sbuf->io) && sbuf->sock == 0);
153         AssertSanity(sbuf);
154
155         /*
156          * common stuff
157          */
158         sock = socket(sa->sa_family, SOCK_STREAM, 0);
159         if (sock < 0) {
160                 /* probably fd limit */
161                 goto failed;
162         }
163
164         if (!tune_socket(sock, is_unix))
165                 goto failed;
166
167         sbuf->sock = sock;
168
169         timeout.tv_sec = timeout_sec;
170         timeout.tv_usec = 0;
171
172         /* launch connection */
173         res = safe_connect(sock, sa, sa_len);
174         if (res == 0) {
175                 /* unix socket gives connection immediately */
176                 sbuf_connect_cb(sock, EV_WRITE, sbuf);
177                 return true;
178         } else if (errno == EINPROGRESS || errno == EAGAIN) {
179                 /* tcp socket needs waiting */
180                 event_set(&sbuf->ev, sock, EV_WRITE, sbuf_connect_cb, sbuf);
181                 res = event_add(&sbuf->ev, &timeout);
182                 if (res >= 0) {
183                         sbuf->wait_type = W_CONNECT;
184                         return true;
185                 }
186         }
187
188 failed:
189         log_warning("sbuf_connect failed: %s", strerror(errno));
190
191         if (sock >= 0)
192                 safe_close(sock);
193         sbuf->sock = 0;
194         sbuf_call_proto(sbuf, SBUF_EV_CONNECT_FAILED);
195         return false;
196 }
197
198 /* don't wait for data on this socket */
199 bool sbuf_pause(SBuf *sbuf)
200 {
201         AssertActive(sbuf);
202         Assert(sbuf->wait_type == W_RECV);
203
204         if (event_del(&sbuf->ev) < 0) {
205                 log_warning("event_del: %s", strerror(errno));
206                 return false;
207         }
208         sbuf->wait_type = W_NONE;
209         return true;
210 }
211
212 /* resume from pause, start waiting for data */
213 void sbuf_continue(SBuf *sbuf)
214 {
215         bool do_recv = DO_RECV;
216         bool res;
217         AssertActive(sbuf);
218
219         res = sbuf_wait_for_data(sbuf);
220         if (!res) {
221                 /* drop if problems */
222                 sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
223                 return;
224         }
225
226         /*
227          * It's tempting to try to avoid the recv() but that would
228          * only work if no code wants to see full packet.
229          *
230          * This is not true in ServerParameter case.
231          */
232         /*
233          * if (sbuf->recv_pos - sbuf->pkt_pos >= SBUF_SMALL_PKT)
234          *      do_recv = false;
235          */
236
237         sbuf_main_loop(sbuf, do_recv);
238 }
239
240 /*
241  * Resume from pause and give socket over to external
242  * callback function.
243  *
244  * The callback will be called with arg given to sbuf_init.
245  */
246 bool sbuf_continue_with_callback(SBuf *sbuf, sbuf_libevent_cb user_cb)
247 {
248         int err;
249
250         AssertActive(sbuf);
251
252         event_set(&sbuf->ev, sbuf->sock, EV_READ | EV_PERSIST,
253                   user_cb, sbuf);
254
255         err = event_add(&sbuf->ev, NULL);
256         if (err < 0) {
257                 log_warning("sbuf_continue_with_callback: %s", strerror(errno));
258                 return false;
259         }
260         sbuf->wait_type = W_RECV;
261         return true;
262 }
263
264 bool sbuf_use_callback_once(SBuf *sbuf, short ev, sbuf_libevent_cb user_cb)
265 {
266         int err;
267         AssertActive(sbuf);
268
269         if (sbuf->wait_type != W_NONE) {
270                 err = event_del(&sbuf->ev);
271                 sbuf->wait_type = W_NONE; /* make sure its called only once */
272                 if (err < 0) {
273                         log_warning("sbuf_queue_once: event_del failed: %s", strerror(errno));
274                         return false;
275                 }
276         }
277
278         /* setup one one-off event handler */
279         event_set(&sbuf->ev, sbuf->sock, ev, user_cb, sbuf);
280         err = event_add(&sbuf->ev, NULL);
281         if (err < 0) {
282                 log_warning("sbuf_queue_once: event_add failed: %s", strerror(errno));
283                 return false;
284         }
285         sbuf->wait_type = W_ONCE;
286         return true;
287 }
288
289 /* socket cleanup & close: keeps .handler and .arg values */
290 bool sbuf_close(SBuf *sbuf)
291 {
292         if (sbuf->wait_type) {
293                 Assert(sbuf->sock);
294                 /* event_del() acts funny occasionally, debug it */
295                 errno = 0;
296                 if (event_del(&sbuf->ev) < 0) {
297                         if (errno) {
298                                 log_warning("event_del: %s", strerror(errno));
299                         } else {
300                                 log_warning("event_del: libevent error");
301                         }
302                         /* we can retry whole sbuf_close() if needed */
303                         /* if (errno == ENOMEM) return false; */
304                 }
305         }
306         sbuf_op_close(sbuf);
307         sbuf->dst = NULL;
308         sbuf->sock = 0;
309         sbuf->pkt_remain = 0;
310         sbuf->pkt_action = sbuf->wait_type = 0;
311         if (sbuf->io) {
312                 slab_free(iobuf_cache, sbuf->io);
313                 sbuf->io = NULL;
314         }
315         return true;
316 }
317
318 /* proto_fn tells to send some bytes to socket */
319 void sbuf_prepare_send(SBuf *sbuf, SBuf *dst, unsigned amount)
320 {
321         AssertActive(sbuf);
322         Assert(sbuf->pkt_remain == 0);
323         /* Assert(sbuf->pkt_action == ACT_UNSET || sbuf->pkt_action == ACT_SEND || iobuf_amount_pending(&sbuf->io)); */
324         Assert(amount > 0);
325
326         sbuf->pkt_action = ACT_SEND;
327         sbuf->pkt_remain = amount;
328         sbuf->dst = dst;
329 }
330
331 /* proto_fn tells to skip some amount of bytes */
332 void sbuf_prepare_skip(SBuf *sbuf, unsigned amount)
333 {
334         AssertActive(sbuf);
335         Assert(sbuf->pkt_remain == 0);
336         /* Assert(sbuf->pkt_action == ACT_UNSET || iobuf_send_pending_avail(&sbuf->io)); */
337         Assert(amount > 0);
338
339         sbuf->pkt_action = ACT_SKIP;
340         sbuf->pkt_remain = amount;
341 }
342
343 /* proto_fn tells to skip some amount of bytes */
344 void sbuf_prepare_fetch(SBuf *sbuf, unsigned amount)
345 {
346         AssertActive(sbuf);
347         Assert(sbuf->pkt_remain == 0);
348         /* Assert(sbuf->pkt_action == ACT_UNSET || iobuf_send_pending_avail(&sbuf->io)); */
349         Assert(amount > 0);
350
351         sbuf->pkt_action = ACT_CALL;
352         sbuf->pkt_remain = amount;
353         /* sbuf->dst = NULL; // FIXME ?? */
354 }
355
356 /*************************
357  * Internal functions
358  *************************/
359
360 /*
361  * Call proto callback with proper struct MBuf.
362  *
363  * If callback returns true it used one of sbuf_prepare_* on sbuf,
364  * and processing can continue.
365  *
366  * If it returned false it used sbuf_pause(), sbuf_close() or simply
367  * wants to wait for next event loop (e.g. too few data available).
368  * Callee should not touch sbuf in that case and just return to libevent.
369  */
370 static bool sbuf_call_proto(SBuf *sbuf, int event)
371 {
372         struct MBuf mbuf;
373         IOBuf *io = sbuf->io;
374         bool res;
375
376         AssertSanity(sbuf);
377         Assert(event != SBUF_EV_READ || iobuf_amount_parse(io) > 0);
378
379         /* if pkt callback, limit only with current packet */
380         if (event == SBUF_EV_PKT_CALLBACK) {
381                 iobuf_parse_limit(io, &mbuf, sbuf->pkt_remain);
382         } else if (event == SBUF_EV_READ) {
383                 iobuf_parse_all(io, &mbuf);
384         } else {
385                 memset(&mbuf, 0, sizeof(mbuf));
386         }
387         res = sbuf->proto_cb(sbuf, event, &mbuf);
388
389         AssertSanity(sbuf);
390         Assert(event != SBUF_EV_READ || !res || sbuf->sock > 0);
391
392         return res;
393 }
394
395 /* let's wait for new data */
396 static bool sbuf_wait_for_data(SBuf *sbuf)
397 {
398         int err;
399
400         event_set(&sbuf->ev, sbuf->sock, EV_READ | EV_PERSIST, sbuf_recv_cb, sbuf);
401         err = event_add(&sbuf->ev, NULL);
402         if (err < 0) {
403                 log_warning("sbuf_wait_for_data: event_add failed: %s", strerror(errno));
404                 return false;
405         }
406         sbuf->wait_type = W_RECV;
407         return true;
408 }
409
410 static void sbuf_recv_forced_cb(evutil_socket_t sock, short flags, void *arg)
411 {
412         SBuf *sbuf = arg;
413
414         sbuf->wait_type = W_NONE;
415
416         if (sbuf_wait_for_data(sbuf)) {
417                 sbuf_recv_cb(sock, flags, arg);
418         } else {
419                 sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
420         }
421 }
422
423 static bool sbuf_wait_for_data_forced(SBuf *sbuf)
424 {
425         int err;
426         struct timeval tv_min;
427
428         tv_min.tv_sec = 0;
429         tv_min.tv_usec = 1;
430
431         if (sbuf->wait_type != W_NONE) {
432                 event_del(&sbuf->ev);
433                 sbuf->wait_type = W_NONE;
434         }
435
436         event_set(&sbuf->ev, sbuf->sock, EV_READ, sbuf_recv_forced_cb, sbuf);
437         err = event_add(&sbuf->ev, &tv_min);
438         if (err < 0) {
439                 log_warning("sbuf_wait_for_data: event_add failed: %s", strerror(errno));
440                 return false;
441         }
442         sbuf->wait_type = W_ONCE;
443         return true;
444 }
445
446 /* libevent EV_WRITE: called when dest socket is writable again */
447 static void sbuf_send_cb(evutil_socket_t sock, short flags, void *arg)
448 {
449         SBuf *sbuf = arg;
450         bool res;
451
452         /* sbuf was closed before in this loop */
453         if (!sbuf->sock)
454                 return;
455
456         AssertSanity(sbuf);
457         Assert(sbuf->wait_type == W_SEND);
458
459         sbuf->wait_type = W_NONE;
460
461         /* prepare normal situation for sbuf_main_loop */
462         res = sbuf_wait_for_data(sbuf);
463         if (res) {
464                 /* here we should certainly skip recv() */
465                 sbuf_main_loop(sbuf, SKIP_RECV);
466         } else {
467                 /* drop if problems */
468                 sbuf_call_proto(sbuf, SBUF_EV_SEND_FAILED);
469         }
470 }
471
472 /* socket is full, wait until it's writable again */
473 static bool sbuf_queue_send(SBuf *sbuf)
474 {
475         int err;
476         AssertActive(sbuf);
477         Assert(sbuf->wait_type == W_RECV);
478
479         /* if false is returned, the socket will be closed later */
480
481         /* stop waiting for read events */
482         err = event_del(&sbuf->ev);
483         sbuf->wait_type = W_NONE; /* make sure its called only once */
484         if (err < 0) {
485                 log_warning("sbuf_queue_send: event_del failed: %s", strerror(errno));
486                 return false;
487         }
488
489         /* instead wait for EV_WRITE on destination socket */
490         event_set(&sbuf->ev, sbuf->dst->sock, EV_WRITE, sbuf_send_cb, sbuf);
491         err = event_add(&sbuf->ev, NULL);
492         if (err < 0) {
493                 log_warning("sbuf_queue_send: event_add failed: %s", strerror(errno));
494                 return false;
495         }
496         sbuf->wait_type = W_SEND;
497
498         return true;
499 }
500
501 /*
502  * There's data in buffer to be sent. Returns bool if processing can continue.
503  *
504  * Does not look at pkt_pos/remain fields, expects them to be merged to send_*
505  */
506 static bool sbuf_send_pending(SBuf *sbuf)
507 {
508         int res, avail;
509         IOBuf *io = sbuf->io;
510
511         AssertActive(sbuf);
512         Assert(sbuf->dst || iobuf_amount_pending(io) == 0);
513
514 try_more:
515         /* how much data is available for sending */
516         avail = iobuf_amount_pending(io);
517         if (avail == 0)
518                 return true;
519
520         if (sbuf->dst->sock == 0) {
521                 log_error("sbuf_send_pending: no dst sock?");
522                 return false;
523         }
524
525         /* actually send it */
526         //res = iobuf_send_pending(io, sbuf->dst->sock);
527         res = sbuf_op_send(sbuf->dst, io->buf + io->done_pos, avail);
528         if (res > 0) {
529                 io->done_pos += res;
530         } else if (res < 0) {
531                 if (errno == EAGAIN) {
532                         if (!sbuf_queue_send(sbuf))
533                                 /* drop if queue failed */
534                                 sbuf_call_proto(sbuf, SBUF_EV_SEND_FAILED);
535                 } else {
536                         sbuf_call_proto(sbuf, SBUF_EV_SEND_FAILED);
537                 }
538                 return false;
539         }
540
541         AssertActive(sbuf);
542
543         /*
544          * Should do sbuf_queue_send() immediately?
545          *
546          * To be sure, let's run into EAGAIN.
547          */
548         goto try_more;
549 }
550
551 /* process as much data as possible */
552 static bool sbuf_process_pending(SBuf *sbuf)
553 {
554         unsigned avail;
555         IOBuf *io = sbuf->io;
556         bool full = iobuf_amount_recv(io) <= 0;
557         bool res;
558
559         while (1) {
560                 AssertActive(sbuf);
561
562                 /*
563                  * Enough for now?
564                  *
565                  * The (avail <= SBUF_SMALL_PKT) check is to avoid partial pkts.
566                  * As SBuf should not assume knowledge about packets,
567                  * the check is not done in !full case.  Packet handler can
568                  * then still notify about partial packet by returning false.
569                  */
570                 avail = iobuf_amount_parse(io);
571                 if (avail == 0 || (full && avail <= SBUF_SMALL_PKT))
572                         break;
573
574                 /*
575                  * If start of packet, process packet header.
576                  */
577                 if (sbuf->pkt_remain == 0) {
578                         res = sbuf_call_proto(sbuf, SBUF_EV_READ);
579                         if (!res)
580                                 return false;
581                         Assert(sbuf->pkt_remain > 0);
582                 }
583
584                 if (sbuf->pkt_action == ACT_SKIP || sbuf->pkt_action == ACT_CALL) {
585                         /* send any pending data before skipping */
586                         if (iobuf_amount_pending(io) > 0) {
587                                 res = sbuf_send_pending(sbuf);
588                                 if (!res)
589                                         return res;
590                         }
591                 }
592
593                 if (avail > sbuf->pkt_remain)
594                         avail = sbuf->pkt_remain;
595
596                 switch (sbuf->pkt_action) {
597                 case ACT_SEND:
598                         iobuf_tag_send(io, avail);
599                         break;
600                 case ACT_CALL:
601                         res = sbuf_call_proto(sbuf, SBUF_EV_PKT_CALLBACK);
602                         if (!res)
603                                 return false;
604                         /* fallthrough */
605                         /* after callback, skip pkt */
606                 case ACT_SKIP:
607                         iobuf_tag_skip(io, avail);
608                         break;
609                 }
610                 sbuf->pkt_remain -= avail;
611         }
612
613         return sbuf_send_pending(sbuf);
614 }
615
616 /* reposition at buffer start again */
617 static void sbuf_try_resync(SBuf *sbuf, bool release)
618 {
619         IOBuf *io = sbuf->io;
620
621         if (io) {
622                 log_noise("resync(%d): done=%d, parse=%d, recv=%d",
623                           sbuf->sock,
624                           io->done_pos, io->parse_pos, io->recv_pos);
625         }
626         AssertActive(sbuf);
627
628         if (!io)
629                 return;
630
631         if (release && iobuf_empty(io)) {
632                 slab_free(iobuf_cache, io);
633                 sbuf->io = NULL;
634         } else {
635                 iobuf_try_resync(io, SBUF_SMALL_PKT);
636         }
637 }
638
639 /* actually ask kernel for more data */
640 static bool sbuf_actual_recv(SBuf *sbuf, unsigned len)
641 {
642         int got;
643         IOBuf *io = sbuf->io;
644         uint8_t *dst = io->buf + io->recv_pos;
645         unsigned avail = iobuf_amount_recv(io);
646         if (len > avail)
647                 len = avail;
648         got = sbuf_op_recv(sbuf, dst, len);
649         if (got > 0) {
650                 io->recv_pos += got;
651         } else if (got == 0) {
652                 /* eof from socket */
653                 sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
654                 return false;
655         } else if (got < 0 && errno != EAGAIN) {
656                 /* some error occurred */
657                 sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
658                 return false;
659         }
660         return true;
661 }
662
663 /* callback for libevent EV_READ */
664 static void sbuf_recv_cb(evutil_socket_t sock, short flags, void *arg)
665 {
666         SBuf *sbuf = arg;
667         sbuf_main_loop(sbuf, DO_RECV);
668 }
669
670 static bool allocate_iobuf(SBuf *sbuf)
671 {
672         if (sbuf->io == NULL) {
673                 sbuf->io = slab_alloc(iobuf_cache);
674                 if (sbuf->io == NULL) {
675                         sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
676                         return false;
677                 }
678                 iobuf_reset(sbuf->io);
679         }
680         return true;
681 }
682
683 /*
684  * Main recv-parse-send-repeat loop.
685  *
686  * Reason for skip_recv is to avoid extra recv().  The problem with it
687  * is EOF from socket.  Currently that means that the pending data is
688  * dropped.  Fortunately server sockets are not paused and dropping
689  * data from client is no problem.  So only place where skip_recv is
690  * important is sbuf_send_cb().
691  */
692 static void sbuf_main_loop(SBuf *sbuf, bool skip_recv)
693 {
694         unsigned free, ok;
695         int loopcnt = 0;
696
697         /* sbuf was closed before in this event loop */
698         if (!sbuf->sock)
699                 return;
700
701         /* reading should be disabled when waiting */
702         Assert(sbuf->wait_type == W_RECV);
703         AssertSanity(sbuf);
704
705         if (!allocate_iobuf(sbuf))
706                 return;
707
708         /* avoid recv() if asked */
709         if (skip_recv)
710                 goto skip_recv;
711
712 try_more:
713         /* make room in buffer */
714         sbuf_try_resync(sbuf, false);
715
716         /* avoid spending too much time on single socket */
717         if (cf_sbuf_loopcnt > 0 && loopcnt >= cf_sbuf_loopcnt) {
718                 bool _ignore;
719
720                 log_debug("loopcnt full");
721                 /*
722                  * sbuf_process_pending() avoids some data if buffer is full,
723                  * but as we exit processing loop here, we need to retry
724                  * after resync to process all data. (result is ignored)
725                  */
726                 _ignore = sbuf_process_pending(sbuf);
727                 (void) _ignore;
728
729                 sbuf_wait_for_data_forced(sbuf);
730                 return;
731         }
732         loopcnt++;
733
734         /*
735          * here used to be if (free > SBUF_SMALL_PKT) check
736          * but with skip_recv switch its should not be needed anymore.
737          */
738         free = iobuf_amount_recv(sbuf->io);
739         if (free > 0) {
740                 /*
741                  * When suspending, try to hit packet boundary ASAP.
742                  */
743                 if (cf_pause_mode == P_SUSPEND
744                     && sbuf->pkt_remain > 0
745                     && sbuf->pkt_remain < free)
746                 {
747                         free = sbuf->pkt_remain;
748                 }
749
750                 /* now fetch the data */
751                 ok = sbuf_actual_recv(sbuf, free);
752                 if (!ok)
753                         return;
754         }
755
756 skip_recv:
757         /* now handle it */
758         ok = sbuf_process_pending(sbuf);
759         if (!ok)
760                 return;
761
762         /* if the buffer is full, there can be more data available */
763         if (iobuf_amount_recv(sbuf->io) <= 0)
764                 goto try_more;
765
766         /* clean buffer */
767         sbuf_try_resync(sbuf, true);
768
769         /* notify proto that all is sent */
770         if (sbuf_is_empty(sbuf))
771                 sbuf_call_proto(sbuf, SBUF_EV_FLUSH);
772
773         if (sbuf->tls_state == SBUF_TLS_DO_HANDSHAKE) {
774                 sbuf->pkt_action = SBUF_TLS_IN_HANDSHAKE;
775                 handle_tls_handshake(sbuf);
776         }
777 }
778
779 /* check if there is any error pending on socket */
780 static bool sbuf_after_connect_check(SBuf *sbuf)
781 {
782         int optval = 0, err;
783         socklen_t optlen = sizeof(optval);
784
785         err = getsockopt(sbuf->sock, SOL_SOCKET, SO_ERROR, (void*)&optval, &optlen);
786         if (err < 0) {
787                 log_debug("sbuf_after_connect_check: getsockopt: %s",
788                           strerror(errno));
789                 return false;
790         }
791         if (optval != 0) {
792                 log_debug("sbuf_after_connect_check: pending error: %s",
793                           strerror(optval));
794                 return false;
795         }
796         return true;
797 }
798
799 /* callback for libevent EV_WRITE when connecting */
800 static void sbuf_connect_cb(evutil_socket_t sock, short flags, void *arg)
801 {
802         SBuf *sbuf = arg;
803
804         Assert(sbuf->wait_type == W_CONNECT || sbuf->wait_type == W_NONE);
805         sbuf->wait_type = W_NONE;
806
807         if (flags & EV_WRITE) {
808                 if (!sbuf_after_connect_check(sbuf))
809                         goto failed;
810                 if (!sbuf_call_proto(sbuf, SBUF_EV_CONNECT_OK))
811                         return;
812                 if (!sbuf_wait_for_data(sbuf))
813                         goto failed;
814                 return;
815         }
816 failed:
817         sbuf_call_proto(sbuf, SBUF_EV_CONNECT_FAILED);
818 }
819
820 /* send some data to listening socket */
821 bool sbuf_answer(SBuf *sbuf, const void *buf, unsigned len)
822 {
823         int res;
824         if (sbuf->sock <= 0)
825                 return false;
826         res = sbuf_op_send(sbuf, buf, len);
827         if (res < 0) {
828                 log_debug("sbuf_answer: error sending: %s", strerror(errno));
829         } else if ((unsigned)res != len) {
830                 log_debug("sbuf_answer: partial send: len=%d sent=%d", len, res);
831         }
832         return (unsigned)res == len;
833 }
834
835 /*
836  * Standard IO ops.
837  */
838
839 static int raw_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len)
840 {
841         return safe_recv(sbuf->sock, dst, len, 0);
842 }
843
844 static int raw_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len)
845 {
846         return safe_send(sbuf->sock, data, len, 0);
847 }
848
849 static int raw_sbufio_close(struct SBuf *sbuf)
850 {
851         if (sbuf->sock > 0) {
852                 safe_close(sbuf->sock);
853                 sbuf->sock = 0;
854         }
855         return 0;
856 }
857
858 /*
859  * TLS support.
860  */
861
862 #ifdef USE_TLS
863
864 static struct tls_config *client_accept_conf;
865 static struct tls_config *server_connect_conf;
866 static struct tls *client_accept_base;
867
868 /*
869  * TLS setup
870  */
871
872 static void setup_tls(struct tls_config *conf, const char *pfx, int sslmode,
873                       const char *protocols, const char *ciphers,
874                       const char *keyfile, const char *certfile, const char *cafile,
875                       const char *dheparams, const char *ecdhecurve,
876                       bool does_connect)
877 {
878         int err;
879         if (*protocols) {
880                 uint32_t protos = TLS_PROTOCOLS_ALL;
881                 err = tls_config_parse_protocols(&protos, protocols);
882                 if (err) {
883                         log_error("invalid %s_protocols: %s", pfx, protocols);
884                 } else {
885                         tls_config_set_protocols(conf, protos);
886                 }
887         }
888         if (*ciphers) {
889                 err = tls_config_set_ciphers(conf, ciphers);
890                 if (err)
891                         log_error("invalid %s_ciphers: %s", pfx, ciphers);
892         }
893         if (*dheparams) {
894                 err = tls_config_set_dheparams(conf, dheparams);
895                 if (err)
896                         log_error("invalid %s_dheparams: %s", pfx, dheparams);
897         }
898         if (*ecdhecurve) {
899                 err = tls_config_set_ecdhecurve(conf, ecdhecurve);
900                 if (err)
901                         log_error("invalid %s_ecdhecurve: %s", pfx, ecdhecurve);
902         }
903         if (*cafile) {
904                 err = tls_config_set_ca_file(conf, cafile);
905                 if (err)
906                         log_error("invalid %s_ca_file: %s", pfx, cafile);
907         }
908         if (*keyfile) {
909                 err = tls_config_set_key_file(conf, keyfile);
910                 if (err)
911                         log_error("invalid %s_key_file: %s", pfx, keyfile);
912         }
913         if (*certfile) {
914                 err = tls_config_set_cert_file(conf, certfile);
915                 if (err)
916                         log_error("invalid %s_cert_file: %s", pfx, certfile);
917         }
918
919         if (does_connect) {
920                 /* TLS client, check server? */
921                 if (sslmode == SSLMODE_VERIFY_FULL) {
922                         tls_config_verify(conf);
923                 } else if (sslmode == SSLMODE_VERIFY_CA) {
924                         tls_config_verify(conf);
925                         tls_config_insecure_noverifyname(conf);
926                 } else {
927                         tls_config_insecure_noverifycert(conf);
928                         tls_config_insecure_noverifyname(conf);
929                 }
930         } else {
931                 /* TLS server, check client? */
932                 if (sslmode == SSLMODE_VERIFY_FULL) {
933                         tls_config_verify_client(conf);
934                 } else if (sslmode == SSLMODE_VERIFY_CA) {
935                         tls_config_verify_client(conf);
936                 } else {
937                         tls_config_verify_client_optional(conf);
938                 }
939         }
940 }
941
942 void sbuf_tls_setup(void)
943 {
944         int err;
945
946         if (cf_client_tls_sslmode != SSLMODE_DISABLED) {
947                 if (!*cf_client_tls_key_file || !*cf_client_tls_cert_file)
948                         die("To allow TLS connections from clients, client_tls_key_file and client_tls_cert_file must be set.");
949         }
950         if (cf_auth_type == AUTH_CERT) {
951                 if (cf_client_tls_sslmode != SSLMODE_VERIFY_FULL)
952                         die("auth_type=cert requires client_tls_sslmode=SSLMODE_VERIFY_FULL");
953                 if (*cf_client_tls_ca_file == '\0')
954                         die("auth_type=cert requires client_tls_ca_file");
955         } else if (cf_client_tls_sslmode > SSLMODE_VERIFY_CA && *cf_client_tls_ca_file == '\0') {
956                 die("client_tls_sslmode requires client_tls_ca_file");
957         }
958
959         err = tls_init();
960         if (err)
961                 fatal("tls_init failed");
962
963         if (cf_server_tls_sslmode != SSLMODE_DISABLED) {
964                 server_connect_conf = tls_config_new();
965                 if (!server_connect_conf)
966                         die("tls_config_new failed 1");
967                 setup_tls(server_connect_conf, "server_tls", cf_server_tls_sslmode,
968                           cf_server_tls_protocols, cf_server_tls_ciphers,
969                           cf_server_tls_key_file, cf_server_tls_cert_file,
970                           cf_server_tls_ca_file, "", "", true);
971         }
972
973         if (cf_client_tls_sslmode != SSLMODE_DISABLED) {
974                 client_accept_conf = tls_config_new();
975                 if (!client_accept_conf)
976                         die("tls_config_new failed 2");
977                 setup_tls(client_accept_conf, "client_tls", cf_client_tls_sslmode,
978                           cf_client_tls_protocols, cf_client_tls_ciphers,
979                           cf_client_tls_key_file, cf_client_tls_cert_file,
980                           cf_client_tls_ca_file, cf_client_tls_dheparams,
981                           cf_client_tls_ecdhecurve, false);
982
983                 client_accept_base = tls_server();
984                 if (!client_accept_base)
985                         die("server_base failed");
986                 err = tls_configure(client_accept_base, client_accept_conf);
987                 if (err)
988                         die("TLS setup failed: %s", tls_error(client_accept_base));
989         }
990 }
991
992 /*
993  * TLS handshake
994  */
995
996 static bool handle_tls_handshake(SBuf *sbuf)
997 {
998         int err;
999
1000         err = tls_handshake(sbuf->tls);
1001         log_noise("tls_handshake: err=%d", err);
1002         if (err == TLS_WANT_POLLIN) {
1003                 return sbuf_use_callback_once(sbuf, EV_READ, sbuf_tls_handshake_cb);
1004         } else if (err == TLS_WANT_POLLOUT) {
1005                 return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_handshake_cb);
1006         } else if (err == 0) {
1007                 sbuf->tls_state = SBUF_TLS_OK;
1008                 sbuf_call_proto(sbuf, SBUF_EV_TLS_READY);
1009                 return true;
1010         } else {
1011                 log_warning("TLS handshake error: %s", tls_error(sbuf->tls));
1012                 return false;
1013         }
1014 }
1015
1016 static void sbuf_tls_handshake_cb(evutil_socket_t fd, short flags, void *_sbuf)
1017 {
1018         SBuf *sbuf = _sbuf;
1019         sbuf->wait_type = W_NONE;
1020         if (!handle_tls_handshake(sbuf))
1021                 sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
1022 }
1023
1024 /*
1025  * Accept TLS connection.
1026  */
1027
1028 bool sbuf_tls_accept(SBuf *sbuf)
1029 {
1030         int err;
1031
1032         if (!sbuf_pause(sbuf))
1033                 return false;
1034
1035         sbuf->ops = &tls_sbufio_ops;
1036
1037         err = tls_accept_fds(client_accept_base, &sbuf->tls, sbuf->sock, sbuf->sock);
1038         log_noise("tls_accept_fds: err=%d", err);
1039         if (err < 0) {
1040                 log_warning("TLS accept error: %s", tls_error(sbuf->tls));
1041                 return false;
1042         }
1043
1044         sbuf->tls_state = SBUF_TLS_DO_HANDSHAKE;
1045         return true;
1046 }
1047
1048 /*
1049  * Connect to remote TLS host.
1050  */
1051
1052 bool sbuf_tls_connect(SBuf *sbuf, const char *hostname)
1053 {
1054         struct tls *ctls;
1055         int err;
1056
1057         if (!sbuf_pause(sbuf))
1058                 return false;
1059
1060         if (cf_server_tls_sslmode != SSLMODE_VERIFY_FULL)
1061                 hostname = NULL;
1062
1063         ctls = tls_client();
1064         if (!ctls)
1065                 return false;
1066         err = tls_configure(ctls, server_connect_conf);
1067         if (err < 0) {
1068                 log_error("tls client config failed: %s", tls_error(ctls));
1069                 tls_free(ctls);
1070                 return false;
1071         }
1072
1073         sbuf->tls = ctls;
1074         sbuf->tls_host = hostname;
1075         sbuf->ops = &tls_sbufio_ops;
1076
1077         err = tls_connect_fds(sbuf->tls, sbuf->sock, sbuf->sock, sbuf->tls_host);
1078         if (err < 0) {
1079                 log_warning("TLS connect error: %s", tls_error(sbuf->tls));
1080                 return false;
1081         }
1082
1083         sbuf->tls_state = SBUF_TLS_DO_HANDSHAKE;
1084         return true;
1085 }
1086
1087 /*
1088  * TLS IO ops.
1089  */
1090
1091 static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len)
1092 {
1093         ssize_t out = 0;
1094
1095         if (sbuf->tls_state != SBUF_TLS_OK) {
1096                 errno = EIO;
1097                 return -1;
1098         }
1099
1100         out = tls_read(sbuf->tls, dst, len);
1101         log_noise("tls_read: req=%u out=%d", len, (int)out);
1102         if (out >= 0) {
1103                 return out;
1104         } else if (out == TLS_WANT_POLLIN) {
1105                 errno = EAGAIN;
1106         } else if (out == TLS_WANT_POLLOUT) {
1107                 log_warning("tls_sbufio_recv: got TLS_WANT_POLLOUT");
1108                 errno = EIO;
1109         } else {
1110                 log_warning("tls_sbufio_recv: %s", tls_error(sbuf->tls));
1111                 errno = EIO;
1112         }
1113         return -1;
1114 }
1115
1116 static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len)
1117 {
1118         ssize_t out;
1119
1120         if (sbuf->tls_state != SBUF_TLS_OK) {
1121                 errno = EIO;
1122                 return -1;
1123         }
1124
1125         out = tls_write(sbuf->tls, data, len);
1126         log_noise("tls_write: req=%u out=%d", len, (int)out);
1127         if (out >= 0) {
1128                 return out;
1129         } else if (out == TLS_WANT_POLLOUT) {
1130                 errno = EAGAIN;
1131         } else if (out == TLS_WANT_POLLIN) {
1132                 log_warning("tls_sbufio_send: got TLS_WANT_POLLIN");
1133                 errno = EIO;
1134         } else {
1135                 log_warning("tls_sbufio_send: %s", tls_error(sbuf->tls));
1136                 errno = EIO;
1137         }
1138         return -1;
1139 }
1140
1141 static int tls_sbufio_close(struct SBuf *sbuf)
1142 {
1143         log_noise("tls_close");
1144         if (sbuf->tls) {
1145                 tls_close(sbuf->tls);
1146                 tls_free(sbuf->tls);
1147                 sbuf->tls = NULL;
1148         }
1149         if (sbuf->sock > 0) {
1150                 safe_close(sbuf->sock);
1151                 sbuf->sock = 0;
1152         }
1153         return 0;
1154 }
1155
1156 void sbuf_cleanup(void)
1157 {
1158         tls_free(client_accept_base);
1159         tls_config_free(client_accept_conf);
1160         tls_config_free(server_connect_conf);
1161         client_accept_conf = NULL;
1162         server_connect_conf = NULL;
1163         client_accept_base = NULL;
1164 }
1165
1166 #else
1167
1168 void sbuf_tls_setup(void) { }
1169 bool sbuf_tls_accept(SBuf *sbuf) { return false; }
1170 bool sbuf_tls_connect(SBuf *sbuf, const char *hostname) { return false; }
1171
1172 void sbuf_cleanup(void)
1173 {
1174 }
1175
1176 static bool handle_tls_handshake(SBuf *sbuf)
1177 {
1178         return false;
1179 }
1180
1181 #endif