From: Marko Kreen Date: Mon, 3 Aug 2015 18:54:49 +0000 (+0300) Subject: Support TLS connections. X-Git-Tag: pgbouncer_1_7_rc1~50 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=baebf4a8dbb68d94e2170c90caf00eadf86f9813;p=pgbouncer Support TLS connections. --- diff --git a/Makefile b/Makefile index 34bad28..b9825b7 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ pgbouncer_SOURCES = \ include/util.h \ include/varcache.h -pgbouncer_CPPFLAGS = -Iinclude $(CARES_CFLAGS) +pgbouncer_CPPFLAGS = -Iinclude $(CARES_CFLAGS) $(TLS_CPPFLAGS) # include libusual sources directly AM_FEATURES = libusual @@ -85,7 +85,8 @@ endif # win32 # -pgbouncer_LDADD := $(CARES_LIBS) $(LIBS) +pgbouncer_LDFLAGS := $(TLS_LDFLAGS) +pgbouncer_LDADD := $(CARES_LIBS) $(TLS_LIBS) $(LIBS) LIBS := EXTRA_pgbouncer_SOURCES = win32/win32support.c win32/win32support.h diff --git a/config.mak.in b/config.mak.in index 5345c05..ba3cc33 100644 --- a/config.mak.in +++ b/config.mak.in @@ -55,6 +55,10 @@ nosub_top_builddir ?= @top_builddir@ CARES_CFLAGS = @CARES_CFLAGS@ CARES_LIBS = @CARES_LIBS@ +TLS_CPPFLAGS = @TLS_CPPFLAGS@ +TLS_LDFLAGS = @TLS_LDFLAGS@ +TLS_LIBS = @TLS_LIBS@ + XMLTO = @XMLTO@ ASCIIDOC = @ASCIIDOC@ DLLWRAP = @DLLWRAP@ diff --git a/configure.ac b/configure.ac index 40407d7..cf9954c 100644 --- a/configure.ac +++ b/configure.ac @@ -181,6 +181,8 @@ fi # !cares ## end of DNS +AC_USUAL_TLS + AC_USUAL_DEBUG AC_USUAL_CASSERT AC_USUAL_WERROR @@ -199,4 +201,5 @@ echo "Results" echo " c-ares = $use_cares" echo " evdns = $use_evdns" echo " udns = $use_udns" +echo " tls = $tls_support" echo "" diff --git a/include/bouncer.h b/include/bouncer.h index e2dedaa..ec28986 100644 --- a/include/bouncer.h +++ b/include/bouncer.h @@ -68,6 +68,15 @@ enum PauseMode { P_SUSPEND = 2 /* wait for buffers to be empty */ }; +enum SSLMode { + SSLMODE_DISABLED, + SSLMODE_ALLOW, + SSLMODE_PREFER, + SSLMODE_REQUIRE, + SSLMODE_VERIFY_CA, + SSLMODE_VERIFY_FULL +}; + #define is_server_socket(sk) ((sk)->state >= SV_FREE) @@ -327,6 +336,8 @@ struct PgSocket { bool own_user:1; /* console client: client with same uid on unix socket */ bool wait_for_response:1;/* console client: waits for completion of PAUSE/SUSPEND cmd */ + bool wait_sslchar:1; /* server: waiting for ssl response: S/N */ + usec_t connect_time; /* when connection was made */ usec_t request_time; /* last activity time */ usec_t query_start; /* query start moment */ @@ -431,6 +442,22 @@ extern int cf_log_disconnections; extern int cf_log_pooler_errors; extern int cf_application_name_add_host; +extern int cf_client_tls_sslmode; +extern char *cf_client_tls_protocols; +extern char *cf_client_tls_ca_file; +extern char *cf_client_tls_cert_file; +extern char *cf_client_tls_key_file; +extern char *cf_client_tls_ciphers; +extern char *cf_client_tls_dheparams; +extern char *cf_client_tls_ecdhecurve; + +extern int cf_server_tls_sslmode; +extern char *cf_server_tls_protocols; +extern char *cf_server_tls_ca_file; +extern char *cf_server_tls_cert_file; +extern char *cf_server_tls_key_file; +extern char *cf_server_tls_ciphers; + extern const struct CfLookup pool_mode_map[]; extern usec_t g_suspend_start; diff --git a/include/pktbuf.h b/include/pktbuf.h index 9a07ffa..1e52649 100644 --- a/include/pktbuf.h +++ b/include/pktbuf.h @@ -106,6 +106,9 @@ void pktbuf_write_ExtQuery(PktBuf *buf, const char *query, int nargs, ...); #define pktbuf_write_Notice(buf, msg) \ pktbuf_write_generic(buf, 'N', "sscss", "SNOTICE", "C00000", 'M', msg, ""); +#define pktbuf_write_SSLRequest(buf) \ + pktbuf_write_generic(buf, PKT_SSLREQ, "") + /* * Shortcut for creating DataRow in memory. */ diff --git a/include/proto.h b/include/proto.h index 3089f42..abd9f6f 100644 --- a/include/proto.h +++ b/include/proto.h @@ -50,6 +50,7 @@ bool welcome_client(PgSocket *client) _MUSTCHECK; bool answer_authreq(PgSocket *server, PktHdr *pkt) _MUSTCHECK; bool send_startup_packet(PgSocket *server) _MUSTCHECK; +bool send_sslreq_packet(PgSocket *server) _MUSTCHECK; int scan_text_result(struct MBuf *pkt, const char *tupdesc, ...) _MUSTCHECK; diff --git a/include/sbuf.h b/include/sbuf.h index 55fdfff..9fcbdbd 100644 --- a/include/sbuf.h +++ b/include/sbuf.h @@ -27,6 +27,7 @@ typedef enum { SBUF_EV_CONNECT_OK, /* got connection */ SBUF_EV_FLUSH, /* data is sent, buffer empty */ SBUF_EV_PKT_CALLBACK, /* next part of pkt data */ + SBUF_EV_TLS_READY /* TLS was established */ } SBufEvent; /* @@ -39,6 +40,8 @@ typedef enum { */ #define SBUF_SMALL_PKT 64 +struct tls; + /* fwd def */ typedef struct SBuf SBuf; typedef struct SBufIO SBufIO; @@ -82,6 +85,8 @@ struct SBuf { IOBuf *io; /* data buffer, lazily allocated */ const SBufIO *ops; /* normal vs. TLS */ + struct tls *tls; /* TLS context */ + const char *tls_host; /* target hostname */ }; #define sbuf_socket(sbuf) ((sbuf)->sock) @@ -90,6 +95,10 @@ void sbuf_init(SBuf *sbuf, sbuf_cb_t proto_fn); bool sbuf_accept(SBuf *sbuf, int read_sock, bool is_unix) _MUSTCHECK; bool sbuf_connect(SBuf *sbuf, const struct sockaddr *sa, int sa_len, int timeout_sec) _MUSTCHECK; +void sbuf_tls_setup(void); +bool sbuf_tls_accept(SBuf *sbuf) _MUSTCHECK; +bool sbuf_tls_connect(SBuf *sbuf, const char *hostname) _MUSTCHECK; + bool sbuf_pause(SBuf *sbuf) _MUSTCHECK; void sbuf_continue(SBuf *sbuf); bool sbuf_close(SBuf *sbuf) _MUSTCHECK; @@ -102,6 +111,7 @@ void sbuf_prepare_fetch(SBuf *sbuf, unsigned amount); bool sbuf_answer(SBuf *sbuf, const void *buf, unsigned len) _MUSTCHECK; bool sbuf_continue_with_callback(SBuf *sbuf, sbuf_libevent_cb cb) _MUSTCHECK; +bool sbuf_use_callback_once(SBuf *sbuf, short ev, sbuf_libevent_cb user_cb) _MUSTCHECK; /* * Returns true if SBuf is has no data buffered diff --git a/include/system.h b/include/system.h index a3417ba..cb531ae 100644 --- a/include/system.h +++ b/include/system.h @@ -32,6 +32,8 @@ #include #include +#include + #ifdef HAVE_CRYPT_H #include #endif diff --git a/lib b/lib index 7dd946a..7177b2a 160000 --- a/lib +++ b/lib @@ -1 +1 @@ -Subproject commit 7dd946ae6023574eefdb9d254faae46805c16c07 +Subproject commit 7177b2af4f65037d19ff193073b06d6347d4b614 diff --git a/src/admin.c b/src/admin.c index ef6239a..bc389db 100644 --- a/src/admin.c +++ b/src/admin.c @@ -270,7 +270,7 @@ static bool send_one_fd(PgSocket *admin, msg.msg_iovlen = 1; /* attach a fd */ - if (pga_is_unix(&admin->remote_addr) && admin->own_user) { + if (pga_is_unix(&admin->remote_addr) && admin->own_user && !admin->sbuf.tls) { msg.msg_control = cntbuf; msg.msg_controllen = sizeof(cntbuf); @@ -314,6 +314,10 @@ static bool show_one_fd(PgSocket *admin, PgSocket *sk) char addrbuf[PGADDR_BUF]; const char *password = NULL; + /* Skip TLS sockets */ + if (sk->sbuf.tls || (sk->link && sk->link->sbuf.tls)) + return true; + mbuf_init_fixed_reader(&tmp, sk->cancel_key, 8); if (!mbuf_get_uint64be(&tmp, &ckey)) return false; @@ -546,8 +550,8 @@ static bool admin_show_users(PgSocket *admin, const char *arg) return true; } -#define SKF_STD "sssssisiTTssi" -#define SKF_DBG "sssssisiTTssiiiiiiii" +#define SKF_STD "sssssisiTTssis" +#define SKF_DBG "sssssisiTTssisiiiiiii" static void socket_header(PktBuf *buf, bool debug) { @@ -555,7 +559,8 @@ static void socket_header(PktBuf *buf, bool debug) "type", "user", "database", "state", "addr", "port", "local_addr", "local_port", "connect_time", "request_time", - "ptr", "link", "remote_pid", + "ptr", "link", "remote_pid", "tls", + /* debug follows */ "recv_pos", "pkt_pos", "pkt_remain", "send_pos", "send_remain", "pkt_avail", "send_avail"); @@ -573,6 +578,7 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug) char ptrbuf[128], linkbuf[128]; char l_addr[PGADDR_BUF], r_addr[PGADDR_BUF]; IOBuf *io = sk->sbuf.io; + char infobuf[96] = ""; if (io) { pkt_avail = iobuf_amount_parse(sk->sbuf.io); @@ -597,6 +603,9 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug) if (is_server_socket(sk) && remote_pid == 0) remote_pid = be32dec(sk->cancel_key); + if (sk->sbuf.tls) + tls_get_connection_info(sk->sbuf.tls, infobuf, sizeof infobuf); + pktbuf_write_DataRow(buf, debug ? SKF_DBG : SKF_STD, is_server_socket(sk) ? "S" :"C", sk->auth_user ? sk->auth_user->name : "(nouser)", @@ -605,7 +614,8 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug) l_addr, pga_port(&sk->local_addr), sk->connect_time, sk->request_time, - ptrbuf, linkbuf, remote_pid, + ptrbuf, linkbuf, remote_pid, infobuf, + /* debug */ io ? io->recv_pos : 0, io ? io->parse_pos : 0, sk->sbuf.pkt_remain, diff --git a/src/client.c b/src/client.c index ec85374..6c243bd 100644 --- a/src/client.c +++ b/src/client.c @@ -141,8 +141,17 @@ static bool finish_set_pool(PgSocket *client, bool takeover) return false; } - if (cf_log_connections) - slog_info(client, "login attempt: db=%s user=%s", client->db->name, client->auth_user->name); + if (cf_log_connections) { + if (client->sbuf.tls) { + char infobuf[96] = ""; + tls_get_connection_info(client->sbuf.tls, infobuf, sizeof infobuf); + slog_info(client, "login attempt: db=%s user=%s tls=%s", + client->db->name, client->auth_user->name, infobuf); + } else { + slog_info(client, "login attempt: db=%s user=%s tls=no", + client->db->name, client->auth_user->name); + } + } if (!check_fast_fail(client)) return false; @@ -433,9 +442,27 @@ static bool handle_client_startup(PgSocket *client, PktHdr *pkt) switch (pkt->type) { case PKT_SSLREQ: slog_noise(client, "C: req SSL"); - slog_noise(client, "P: nak"); +#ifdef USE_TLS + if (client->sbuf.tls) { + disconnect_client(client, false, "SSL req inside SSL"); + return false; + } + if (cf_client_tls_sslmode != SSLMODE_DISABLED) { + slog_noise(client, "P: SSL ack"); + if (!sbuf_answer(&client->sbuf, "S", 1)) { + disconnect_client(client, false, "failed to ack SSL"); + return false; + } + if (!sbuf_tls_accept(&client->sbuf)) { + disconnect_client(client, false, "failed to accept SSL"); + return false; + } + break; + } +#endif /* reject SSL attempt */ + slog_noise(client, "P: nak"); if (!sbuf_answer(&client->sbuf, "N", 1)) { disconnect_client(client, false, "failed to nak SSL"); return false; @@ -445,6 +472,12 @@ static bool handle_client_startup(PgSocket *client, PktHdr *pkt) disconnect_client(client, true, "Old V2 protocol not supported"); return false; case PKT_STARTUP: + /* require SSL except on unix socket */ + if (cf_client_tls_sslmode >= SSLMODE_REQUIRE && !client->sbuf.tls && !pga_is_unix(&client->remote_addr)) { + disconnect_client(client, true, "SSL required"); + return false; + } + if (client->pool && !client->wait_for_user_conn && !client->wait_for_user) { disconnect_client(client, true, "client re-sent startup pkt"); return false; @@ -633,6 +666,10 @@ bool client_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) case SBUF_EV_PKT_CALLBACK: /* unused ATM */ break; + case SBUF_EV_TLS_READY: + sbuf_continue(&client->sbuf); + res = true; + break; } return res; } diff --git a/src/main.c b/src/main.c index f194eb1..ef0344d 100644 --- a/src/main.c +++ b/src/main.c @@ -31,6 +31,11 @@ #include #endif +#ifndef DEFAULT_TLS_CIPHERS +/* enable only PFS, deprioritize/remove slower ones */ +#define DEFAULT_TLS_CIPHERS "EECDH+HIGH:EDH+HIGH:+AES256:+SHA256:+SHA384:+SSLv3:+EDH:-CAMELLIA:-3DES:!DSS:!aNULL" +#endif + static const char usage_str[] = "Usage: %s [OPTION]... config.ini\n" " -d, --daemon Run in background (as a daemon)\n" @@ -137,6 +142,22 @@ int cf_log_disconnections; int cf_log_pooler_errors; int cf_application_name_add_host; +int cf_client_tls_sslmode; +char *cf_client_tls_protocols; +char *cf_client_tls_ca_file; +char *cf_client_tls_cert_file; +char *cf_client_tls_key_file; +char *cf_client_tls_ciphers; +char *cf_client_tls_dheparams; +char *cf_client_tls_ecdhecurve; + +int cf_server_tls_sslmode; +char *cf_server_tls_protocols; +char *cf_server_tls_ca_file; +char *cf_server_tls_cert_file; +char *cf_server_tls_key_file; +char *cf_server_tls_ciphers; + /* * config file description */ @@ -162,6 +183,18 @@ const struct CfLookup pool_mode_map[] = { { NULL } }; +const struct CfLookup sslmode_map[] = { + { "disabled", SSLMODE_DISABLED }, +#ifdef USE_TLS + { "allow", SSLMODE_ALLOW }, + { "prefer", SSLMODE_PREFER }, + { "require", SSLMODE_REQUIRE }, + { "verify-ca", SSLMODE_VERIFY_CA }, + { "verify-full", SSLMODE_VERIFY_FULL }, +#endif + { NULL } +}; + static const struct CfKey bouncer_params [] = { CF_ABS("job_name", CF_STR, cf_jobname, CF_NO_RELOAD, "pgbouncer"), #ifdef WIN32 @@ -235,6 +268,23 @@ CF_ABS("log_connections", CF_INT, cf_log_connections, 0, "1"), CF_ABS("log_disconnections", CF_INT, cf_log_disconnections, 0, "1"), CF_ABS("log_pooler_errors", CF_INT, cf_log_pooler_errors, 0, "1"), CF_ABS("application_name_add_host", CF_INT, cf_application_name_add_host, 0, "0"), + +CF_ABS("client_tls_sslmode", CF_LOOKUP(sslmode_map), cf_client_tls_sslmode, CF_NO_RELOAD, "disabled"), +CF_ABS("client_tls_ca_file", CF_STR, cf_client_tls_ca_file, CF_NO_RELOAD, ""), +CF_ABS("client_tls_cert_file", CF_STR, cf_client_tls_cert_file, CF_NO_RELOAD, ""), +CF_ABS("client_tls_key_file", CF_STR, cf_client_tls_key_file, CF_NO_RELOAD, ""), +CF_ABS("client_tls_protocols", CF_STR, cf_client_tls_protocols, CF_NO_RELOAD, "all"), +CF_ABS("client_tls_ciphers", CF_STR, cf_client_tls_ciphers, CF_NO_RELOAD, DEFAULT_TLS_CIPHERS), +CF_ABS("client_tls_dheparams", CF_STR, cf_client_tls_dheparams, CF_NO_RELOAD, "auto"), +CF_ABS("client_tls_ecdhcurve", CF_STR, cf_client_tls_ecdhecurve, CF_NO_RELOAD, "auto"), + +CF_ABS("server_tls_sslmode", CF_LOOKUP(sslmode_map), cf_server_tls_sslmode, CF_NO_RELOAD, "disabled"), +CF_ABS("server_tls_ca_file", CF_STR, cf_server_tls_ca_file, CF_NO_RELOAD, ""), +CF_ABS("server_tls_cert_file", CF_STR, cf_server_tls_cert_file, CF_NO_RELOAD, ""), +CF_ABS("server_tls_key_file", CF_STR, cf_server_tls_key_file, CF_NO_RELOAD, ""), +CF_ABS("server_tls_protocols", CF_STR, cf_server_tls_protocols, CF_NO_RELOAD, "all"), +CF_ABS("server_tls_ciphers", CF_STR, cf_server_tls_ciphers, CF_NO_RELOAD, DEFAULT_TLS_CIPHERS), + {NULL} }; @@ -736,6 +786,8 @@ int main(int argc, char *argv[]) init_caches(); logging_prefix_cb = log_socket_prefix; + sbuf_tls_setup(); + /* prefer cmdline over config for username */ if (arg_username) { if (cf_username) diff --git a/src/proto.c b/src/proto.c index acce1f6..8aad019 100644 --- a/src/proto.c +++ b/src/proto.c @@ -369,6 +369,13 @@ bool send_startup_packet(PgSocket *server) return pktbuf_send_immediate(pkt, server); } +bool send_sslreq_packet(PgSocket *server) +{ + int res; + SEND_wrap(16, pktbuf_write_SSLRequest, res, server); + return res; +} + int scan_text_result(struct MBuf *pkt, const char *tupdesc, ...) { const char *val = NULL; diff --git a/src/sbuf.c b/src/sbuf.c index a00076f..3c59c9b 100644 --- a/src/sbuf.c +++ b/src/sbuf.c @@ -40,6 +40,7 @@ enum WaitType { W_CONNECT, W_RECV, W_SEND, + W_ONCE }; #define AssertSanity(sbuf) do { \ @@ -77,6 +78,20 @@ static const SBufIO raw_sbufio_ops = { raw_sbufio_close }; +/* I/O over TLS */ +#ifdef USE_TLS +static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len); +static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len); +static int tls_sbufio_close(struct SBuf *sbuf); +static const SBufIO tls_sbufio_ops = { + tls_sbufio_recv, + tls_sbufio_send, + tls_sbufio_close +}; +static void sbuf_tls_accept_cb(int fd, short flags, void *_sbuf); +static void sbuf_tls_connect_cb(int fd, short flags, void *_sbuf); +#endif + /********************************* * Public functions *********************************/ @@ -237,6 +252,31 @@ bool sbuf_continue_with_callback(SBuf *sbuf, sbuf_libevent_cb user_cb) return true; } +bool sbuf_use_callback_once(SBuf *sbuf, short ev, sbuf_libevent_cb user_cb) +{ + int err; + AssertActive(sbuf); + + if (sbuf->wait_type != W_NONE) { + err = event_del(&sbuf->ev); + sbuf->wait_type = W_NONE; /* make sure its called only once */ + if (err < 0) { + log_warning("sbuf_queue_once: event_del failed: %s", strerror(errno)); + return false; + } + } + + /* setup one one-off event handler */ + event_set(&sbuf->ev, sbuf->sock, ev, user_cb, sbuf); + err = event_add(&sbuf->ev, NULL); + if (err < 0) { + log_warning("sbuf_queue_once: event_add failed: %s", strerror(errno)); + return false; + } + sbuf->wait_type = W_ONCE; + return true; +} + /* socket cleanup & close: keeps .handler and .arg values */ bool sbuf_close(SBuf *sbuf) { @@ -758,3 +798,285 @@ static int raw_sbufio_close(struct SBuf *sbuf) return 0; } +/* + * TLS support. + */ + +#ifdef USE_TLS + +static struct tls_config *client_accept_conf; +static struct tls_config *server_connect_conf; +static struct tls *client_accept_base; + +/* + * TLS setup + */ + +static void setup_tls(struct tls_config *conf, const char *pfx, int sslmode, + const char *protocols, const char *ciphers, + const char *keyfile, const char *certfile, const char *cafile, + const char *dheparams, const char *ecdhecurve, + bool does_connect) +{ + int err; + if (*protocols) { + uint32_t protos = TLS_PROTOCOLS_ALL; + err = tls_config_parse_protocols(&protos, protocols); + if (err) { + log_error("Invalid %s_protocols: %s", pfx, protocols); + } else { + tls_config_set_protocols(conf, protos); + } + } + if (*ciphers) { + err = tls_config_set_ciphers(conf, ciphers); + if (err) + log_error("Invalid %s_ciphers: %s", pfx, ciphers); + } + if (*dheparams) { + err = tls_config_set_dheparams(conf, dheparams); + if (err) + log_error("Invalid %s_dheparams: %s", pfx, dheparams); + } + if (*ecdhecurve) { + err = tls_config_set_ecdhecurve(conf, ecdhecurve); + if (err) + log_error("Invalid %s_ecdhecurve: %s", pfx, ecdhecurve); + } + if (*cafile) { + err = tls_config_set_ca_file(conf, cafile); + if (err) + log_error("Invalid %s_ca_file: %s", pfx, cafile); + } + if (*keyfile) { + err = tls_config_set_key_file(conf, keyfile); + if (err) + log_error("Invalid %s_key_file: %s", pfx, keyfile); + } + if (*certfile) { + err = tls_config_set_cert_file(conf, certfile); + if (err) + log_error("Invalid %s_cert_file: %s", pfx, certfile); + } + + if (sslmode == SSLMODE_VERIFY_FULL) { + tls_config_verify(conf); + } else if (sslmode == SSLMODE_VERIFY_CA) { + tls_config_insecure_noverifyname(conf); + } else { + tls_config_insecure_noverifycert(conf); + tls_config_insecure_noverifyname(conf); + } +} + +void sbuf_tls_setup(void) +{ + int err; + + if (cf_client_tls_sslmode != SSLMODE_DISABLED) { + if (!*cf_client_tls_key_file || !*cf_client_tls_cert_file) + die("To allow TLS connections from clients, client_tls_key_file and client_tls_cert_file must be set."); + } + if (cf_auth_type == AUTH_CERT) { + if (cf_client_tls_sslmode != SSLMODE_VERIFY_FULL) + die("auth_type=cert requires client_tls_sslmode=SSLMODE_VERIFY_FULL"); + if (*cf_client_tls_ca_file == '\0') + die("auth_type=cert requires client_tls_ca_file"); + } else if (cf_client_tls_sslmode > SSLMODE_VERIFY_CA && *cf_client_tls_ca_file == '\0') { + die("client_tls_sslmode requires client_tls_ca_file"); + } + + err = tls_init(); + if (err) + fatal("tls_init failed"); + + if (cf_server_tls_sslmode != SSLMODE_DISABLED) { + server_connect_conf = tls_config_new(); + if (!server_connect_conf) + die("tls_config_new failed 1"); + setup_tls(server_connect_conf, "server_tls", cf_server_tls_sslmode, + cf_server_tls_protocols, cf_server_tls_ciphers, + cf_server_tls_key_file, cf_server_tls_cert_file, + cf_server_tls_ca_file, "", "", true); + } + + if (cf_client_tls_sslmode != SSLMODE_DISABLED) { + client_accept_conf = tls_config_new(); + if (!client_accept_conf) + die("tls_config_new failed 2"); + setup_tls(client_accept_conf, "client_tls", cf_client_tls_sslmode, + cf_client_tls_protocols, cf_client_tls_ciphers, + cf_client_tls_key_file, cf_client_tls_cert_file, + cf_client_tls_ca_file, cf_client_tls_dheparams, + cf_client_tls_ecdhecurve, false); + + client_accept_base = tls_server(); + if (!client_accept_base) + die("server_base failed"); + err = tls_configure(client_accept_base, client_accept_conf); + if (err) + die("TLS setup failed: %s", tls_error(client_accept_base)); + } +} + +/* + * Accept TLS connection. + */ + +static bool handle_tls_accept(struct SBuf *sbuf) +{ + int err; + + err = tls_accept_fds(client_accept_base, &sbuf->tls, sbuf->sock, sbuf->sock); + log_noise("tls_accept_fds: err=%d", err); + if (err == TLS_READ_AGAIN) { + return sbuf_use_callback_once(sbuf, EV_READ, sbuf_tls_accept_cb); + } else if (err == TLS_WRITE_AGAIN) { + return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_accept_cb); + } else if (err == 0) { + sbuf_call_proto(sbuf, SBUF_EV_TLS_READY); + return true; + } else { + log_warning("TLS accept error: %s", tls_error(sbuf->tls)); + return false; + } +} + +static void sbuf_tls_accept_cb(int fd, short flags, void *_sbuf) +{ + SBuf *sbuf = _sbuf; + sbuf->wait_type = W_NONE; + if (!handle_tls_accept(sbuf)) + sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED); +} + +bool sbuf_tls_accept(SBuf *sbuf) +{ + sbuf->ops = &tls_sbufio_ops; + return handle_tls_accept(sbuf); +} + +/* + * Connect to remote TLS host. + */ + +static bool handle_tls_connect(SBuf *sbuf) +{ + int err; + + err = tls_connect_fds(sbuf->tls, sbuf->sock, sbuf->sock, sbuf->tls_host); + log_noise("tls_connect_fds: err=%d", err); + if (err == TLS_READ_AGAIN) { + return sbuf_use_callback_once(sbuf, EV_READ, sbuf_tls_connect_cb); + } else if (err == TLS_WRITE_AGAIN) { + return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_connect_cb); + } else if (err == 0) { + sbuf_call_proto(sbuf, SBUF_EV_TLS_READY); + return true; + } else { + log_warning("TLS connect error: %s", tls_error(sbuf->tls)); + return false; + } +} + +static void sbuf_tls_connect_cb(int fd, short flags, void *_sbuf) +{ + SBuf *sbuf = _sbuf; + sbuf->wait_type = W_NONE; + if (!handle_tls_connect(sbuf)) + sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED); +} + +bool sbuf_tls_connect(SBuf *sbuf, const char *hostname) +{ + struct tls *ctls; + int err; + + if (cf_server_tls_sslmode != SSLMODE_VERIFY_FULL) + hostname = NULL; + + ctls = tls_client(); + if (!ctls) + return false; + err = tls_configure(ctls, server_connect_conf); + if (err) { + log_error("tls client config failed: %s", tls_error(ctls)); + tls_free(ctls); + return false; + } + + sbuf->tls = ctls; + sbuf->tls_host = hostname; + sbuf->ops = &tls_sbufio_ops; + + return handle_tls_connect(sbuf); +} + +/* + * TLS IO ops. + */ + +static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len) +{ + int err; + size_t out = 0; + + err = tls_read(sbuf->tls, dst, len, &out); + log_noise("tls_read: req=%u err=%d out=%d", len, err, (int)out); + if (!err) { + return out; + } else if (err == TLS_READ_AGAIN) { + errno = EAGAIN; + } else if (err == TLS_WRITE_AGAIN) { + log_warning("tls_sbufio_recv: got TLS_WRITE_AGAIN"); + errno = EIO; + } else { + log_warning("tls_sbufio_recv: %s", tls_error(sbuf->tls)); + errno = EIO; + } + return -1; +} + +static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len) +{ + size_t out = 0; + int err; + + err = tls_write(sbuf->tls, data, len, &out); + log_noise("tls_write: req=%u err=%d out=%d", len, err, (int)out); + if (!err) { + return out; + } else if (err == TLS_WRITE_AGAIN) { + errno = EAGAIN; + } else if (err == TLS_READ_AGAIN) { + log_warning("tls_sbufio_send: got TLS_READ_AGAIN"); + errno = EIO; + } else { + log_warning("tls_sbufio_send: %s", tls_error(sbuf->tls)); + errno = EIO; + } + return -1; +} + +static int tls_sbufio_close(struct SBuf *sbuf) +{ + log_noise("tls_close"); + if (sbuf->tls) { + tls_close(sbuf->tls); + tls_free(sbuf->tls); + sbuf->tls = NULL; + } + if (sbuf->sock > 0) { + safe_close(sbuf->sock); + sbuf->sock = 0; + } + return 0; +} + +#else + +void sbuf_tls_setup(void) { } +bool sbuf_tls_accept(SBuf *sbuf) { return false; } +bool sbuf_tls_connect(SBuf *sbuf, const char *hostname) { return false; } + +#endif diff --git a/src/server.c b/src/server.c index 1cdb2c3..cb78224 100644 --- a/src/server.c +++ b/src/server.c @@ -369,13 +369,53 @@ static bool handle_connect(PgSocket *server) disconnect_server(server, false, "sent cancel req"); } else { /* proceed with login */ - res = send_startup_packet(server); + if (cf_server_tls_sslmode > SSLMODE_DISABLED) { + slog_noise(server, "P: SSL request"); + res = send_sslreq_packet(server); + if (res) + server->wait_sslchar = true; + } else { + slog_noise(server, "P: startup"); + res = send_startup_packet(server); + } if (!res) disconnect_server(server, false, "startup pkt failed"); } return res; } +static bool handle_sslchar(PgSocket *server, struct MBuf *data) +{ + uint8_t schar = '?'; + bool ok; + + server->wait_sslchar = false; + + ok = mbuf_get_byte(data, &schar); + if (!ok || (schar != 'S' && schar != 'N') || mbuf_avail_for_read(data) != 0) { + disconnect_server(server, false, "bad sslreq answer"); + return false; + } + + if (schar == 'S') { + slog_noise(server, "launching tls"); + ok = sbuf_tls_connect(&server->sbuf, server->pool->db->host); + } else if (cf_server_tls_sslmode >= SSLMODE_REQUIRE) { + disconnect_server(server, false, "server refused SSL"); + return false; + } else { + /* proceed with non-TLS connection */ + ok = send_startup_packet(server); + } + + if (ok) { + sbuf_prepare_skip(&server->sbuf, 1); + } else { + disconnect_server(server, false, "sslreq processing failed"); + } + return ok; +} + /* callback from SBuf */ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) { @@ -383,6 +423,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) PgSocket *server = container_of(sbuf, PgSocket, sbuf); PgPool *pool = server->pool; PktHdr pkt; + char infobuf[96]; Assert(is_server_socket(server)); Assert(server->state != SV_FREE); @@ -399,6 +440,10 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) disconnect_client(server->link, false, "unexpected eof"); break; case SBUF_EV_READ: + if (server->wait_sslchar) { + res = handle_sslchar(server, data); + break; + } if (incomplete_header(data)) { slog_noise(server, "S: got partial header, trying to wait a bit"); break; @@ -468,6 +513,23 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) case SBUF_EV_PKT_CALLBACK: slog_warning(server, "SBUF_EV_PKT_CALLBACK with state=%d", server->state); break; + case SBUF_EV_TLS_READY: + Assert(server->state == SV_LOGIN); + + tls_get_connection_info(server->sbuf.tls, infobuf, sizeof infobuf); + if (cf_log_connections) { + slog_info(server, "SSL established: %s", infobuf); + } else { + slog_noise(server, "SSL established: %s", infobuf); + } + + server->request_time = get_cached_time(); + res = send_startup_packet(server); + if (res) + sbuf_continue(&server->sbuf); + else + disconnect_server(server, false, "TLS startup failed"); + break; } if (!res && pool->db->admin) takeover_login_failed();