]> granicus.if.org Git - pgbouncer/commitdiff
tls: avoid recursive socket loop
authorMarko Kreen <markokr@gmail.com>
Thu, 3 Dec 2015 18:53:13 +0000 (20:53 +0200)
committerMarko Kreen <markokr@gmail.com>
Fri, 4 Dec 2015 10:20:34 +0000 (12:20 +0200)
TLS handshake may happen immediately without
going though libevent poll.  (Loaded CPU with fast
network - local testing).  This will lead to

sbuf_main_loop
 ->sbuf_tls_connect
   ->SBUF_EV_TLS_READY
     ->sbuf_continue
       ->sbuf_main_loop

call which finally end up in sbuf_send_pending()
running on JUSTFREE socket which crashes.

To improve things:

* Always perform sbuf_pause before handshake.
  Otherwise sbuf_continue can be called on
  unpaused socket.

* Move actual handshake out from from sbuf_tls_* functions
  to avoid recursive sbuf_main_loop().

Fixes: #97
include/sbuf.h
src/sbuf.c

index 9fcbdbd794d001d56560492455ca109cce45aed5..75032313cddbce51a53ed75d22d88c73cf5696f5 100644 (file)
@@ -73,6 +73,7 @@ struct SBuf {
 
        uint8_t wait_type;      /* track wait state */
        uint8_t pkt_action;     /* method for handling current pkt */
+       uint8_t tls_state;      /* progress of tls */
 
        int sock;               /* fd for this socket */
 
index ab69fbf0dbc89090aeadcb577f903f45de9f1e8a..b651717e426ec4a12e33cba23eef9311a21be36f 100644 (file)
 #define ACT_SKIP 2
 #define ACT_CALL 3
 
+enum TLSState {
+       SBUF_TLS_NONE,
+       SBUF_TLS_DO_HANDSHAKE,
+       SBUF_TLS_IN_HANDSHAKE,
+       SBUF_TLS_OK,
+};
+
 enum WaitType {
        W_NONE = 0,
        W_CONNECT,
@@ -69,6 +76,7 @@ static void sbuf_main_loop(SBuf *sbuf, bool skip_recv);
 static bool sbuf_call_proto(SBuf *sbuf, int event) /* _MUSTCHECK */;
 static bool sbuf_actual_recv(SBuf *sbuf, unsigned len)  _MUSTCHECK;
 static bool sbuf_after_connect_check(SBuf *sbuf)  _MUSTCHECK;
+static bool handle_tls_handshake(SBuf *sbuf) /* _MUSTCHECK */;
 
 static inline IOBuf *get_iobuf(SBuf *sbuf) { return sbuf->io; }
 
@@ -720,6 +728,11 @@ skip_recv:
        /* notify proto that all is sent */
        if (sbuf_is_empty(sbuf))
                sbuf_call_proto(sbuf, SBUF_EV_FLUSH);
+
+       if (sbuf->tls_state == SBUF_TLS_DO_HANDSHAKE) {
+               sbuf->pkt_action = SBUF_TLS_IN_HANDSHAKE;
+               handle_tls_handshake(sbuf);
+       }
 }
 
 /* check if there is any error pending on socket */
@@ -948,6 +961,7 @@ static bool handle_tls_handshake(SBuf *sbuf)
        } else if (err == TLS_WANT_POLLOUT) {
                return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_handshake_cb);
        } else if (err == 0) {
+               sbuf->tls_state = SBUF_TLS_OK;
                sbuf_call_proto(sbuf, SBUF_EV_TLS_READY);
                return true;
        } else {
@@ -972,6 +986,9 @@ bool sbuf_tls_accept(SBuf *sbuf)
 {
        int err;
 
+       if (!sbuf_pause(sbuf))
+               return false;
+
        sbuf->ops = &tls_sbufio_ops;
 
        err = tls_accept_fds(client_accept_base, &sbuf->tls, sbuf->sock, sbuf->sock);
@@ -981,7 +998,8 @@ bool sbuf_tls_accept(SBuf *sbuf)
                return false;
        }
 
-       return handle_tls_handshake(sbuf);
+       sbuf->tls_state = SBUF_TLS_DO_HANDSHAKE;
+       return true;
 }
 
 /*
@@ -993,6 +1011,9 @@ bool sbuf_tls_connect(SBuf *sbuf, const char *hostname)
        struct tls *ctls;
        int err;
 
+       if (!sbuf_pause(sbuf))
+               return false;
+
        if (cf_server_tls_sslmode != SSLMODE_VERIFY_FULL)
                hostname = NULL;
 
@@ -1016,7 +1037,8 @@ bool sbuf_tls_connect(SBuf *sbuf, const char *hostname)
                return false;
        }
 
-       return handle_tls_handshake(sbuf);
+       sbuf->tls_state = SBUF_TLS_DO_HANDSHAKE;
+       return true;
 }
 
 /*
@@ -1027,6 +1049,11 @@ static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len)
 {
        ssize_t out = 0;
 
+       if (sbuf->tls_state != SBUF_TLS_OK) {
+               errno = EIO;
+               return -1;
+       }
+
        out = tls_read(sbuf->tls, dst, len);
        log_noise("tls_read: req=%u out=%d", len, (int)out);
        if (out >= 0) {
@@ -1047,6 +1074,11 @@ static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len
 {
        ssize_t out;
 
+       if (sbuf->tls_state != SBUF_TLS_OK) {
+               errno = EIO;
+               return -1;
+       }
+
        out = tls_write(sbuf->tls, data, len);
        log_noise("tls_write: req=%u out=%d", len, (int)out);
        if (out >= 0) {