Support TLS connections.
authorMarko Kreen <markokr@gmail.com>
Mon, 3 Aug 2015 18:54:49 +0000 (21:54 +0300)
committerMarko Kreen <markokr@gmail.com>
Mon, 3 Aug 2015 20:20:08 +0000 (23:20 +0300)
15 files changed:
Makefile
config.mak.in
configure.ac
include/bouncer.h
include/pktbuf.h
include/proto.h
include/sbuf.h
include/system.h
lib
src/admin.c
src/client.c
src/main.c
src/proto.c
src/sbuf.c
src/server.c

index 34bad2855c86f03dc32e6eb6992de9c18c77d43a..b9825b766d967842d5f91fc945f2a3be6ed88076 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -40,7 +40,7 @@ pgbouncer_SOURCES = \
        include/util.h \
        include/varcache.h
 
-pgbouncer_CPPFLAGS = -Iinclude $(CARES_CFLAGS)
+pgbouncer_CPPFLAGS = -Iinclude $(CARES_CFLAGS) $(TLS_CPPFLAGS)
 
 # include libusual sources directly
 AM_FEATURES = libusual
@@ -85,7 +85,8 @@ endif
 # win32
 #
 
-pgbouncer_LDADD := $(CARES_LIBS) $(LIBS)
+pgbouncer_LDFLAGS := $(TLS_LDFLAGS)
+pgbouncer_LDADD := $(CARES_LIBS) $(TLS_LIBS) $(LIBS)
 LIBS :=
 
 EXTRA_pgbouncer_SOURCES = win32/win32support.c win32/win32support.h
index 5345c050c105ffbbec12f6dbd06930eceef94384..ba3cc33e5eb048dff4724cb4e9c10256d168e7f6 100644 (file)
@@ -55,6 +55,10 @@ nosub_top_builddir ?= @top_builddir@
 CARES_CFLAGS = @CARES_CFLAGS@
 CARES_LIBS = @CARES_LIBS@
 
+TLS_CPPFLAGS = @TLS_CPPFLAGS@
+TLS_LDFLAGS = @TLS_LDFLAGS@
+TLS_LIBS = @TLS_LIBS@
+
 XMLTO = @XMLTO@
 ASCIIDOC = @ASCIIDOC@
 DLLWRAP = @DLLWRAP@
index 40407d724277daa2e2a5101114d8f617d1fa7f2f..cf9954c2eef8eb4151de60eacdc9e97cc1d87c5d 100644 (file)
@@ -181,6 +181,8 @@ fi # !cares
 
 ## end of DNS
 
+AC_USUAL_TLS
+
 AC_USUAL_DEBUG
 AC_USUAL_CASSERT
 AC_USUAL_WERROR
@@ -199,4 +201,5 @@ echo "Results"
 echo "  c-ares = $use_cares"
 echo "  evdns = $use_evdns"
 echo "  udns = $use_udns"
+echo "  tls = $tls_support"
 echo ""
index e2dedaa88c1d11ef379cc60bfed252cf3e502c80..ec28986f105f899366b75f95f164332cc749d1fb 100644 (file)
@@ -68,6 +68,15 @@ enum PauseMode {
        P_SUSPEND = 2           /* wait for buffers to be empty */
 };
 
+enum SSLMode {
+       SSLMODE_DISABLED,
+       SSLMODE_ALLOW,
+       SSLMODE_PREFER,
+       SSLMODE_REQUIRE,
+       SSLMODE_VERIFY_CA,
+       SSLMODE_VERIFY_FULL
+};
+
 #define is_server_socket(sk) ((sk)->state >= SV_FREE)
 
 
@@ -327,6 +336,8 @@ struct PgSocket {
        bool own_user:1;        /* console client: client with same uid on unix socket */
        bool wait_for_response:1;/* console client: waits for completion of PAUSE/SUSPEND cmd */
 
+       bool wait_sslchar:1;    /* server: waiting for ssl response: S/N */
+
        usec_t connect_time;    /* when connection was made */
        usec_t request_time;    /* last activity time */
        usec_t query_start;     /* query start moment */
@@ -431,6 +442,22 @@ extern int cf_log_disconnections;
 extern int cf_log_pooler_errors;
 extern int cf_application_name_add_host;
 
+extern int cf_client_tls_sslmode;
+extern char *cf_client_tls_protocols;
+extern char *cf_client_tls_ca_file;
+extern char *cf_client_tls_cert_file;
+extern char *cf_client_tls_key_file;
+extern char *cf_client_tls_ciphers;
+extern char *cf_client_tls_dheparams;
+extern char *cf_client_tls_ecdhecurve;
+
+extern int cf_server_tls_sslmode;
+extern char *cf_server_tls_protocols;
+extern char *cf_server_tls_ca_file;
+extern char *cf_server_tls_cert_file;
+extern char *cf_server_tls_key_file;
+extern char *cf_server_tls_ciphers;
+
 extern const struct CfLookup pool_mode_map[];
 
 extern usec_t g_suspend_start;
index 9a07ffa2195f37d4512b6405e11458b634a762b9..1e52649a7e3f4616f0732da25cd482c60f595886 100644 (file)
@@ -106,6 +106,9 @@ void pktbuf_write_ExtQuery(PktBuf *buf, const char *query, int nargs, ...);
 #define pktbuf_write_Notice(buf, msg) \
        pktbuf_write_generic(buf, 'N', "sscss", "SNOTICE", "C00000", 'M', msg, "");
 
+#define pktbuf_write_SSLRequest(buf) \
+       pktbuf_write_generic(buf, PKT_SSLREQ, "")
+
 /*
  * Shortcut for creating DataRow in memory.
  */
index 3089f425fb09efb08213106c0180660b9b892c22..abd9f6feacd2c96ac0814f6a5522f16af2280fef 100644 (file)
@@ -50,6 +50,7 @@ bool welcome_client(PgSocket *client) _MUSTCHECK;
 bool answer_authreq(PgSocket *server, PktHdr *pkt) _MUSTCHECK;
 
 bool send_startup_packet(PgSocket *server) _MUSTCHECK;
+bool send_sslreq_packet(PgSocket *server) _MUSTCHECK;
 
 int scan_text_result(struct MBuf *pkt, const char *tupdesc, ...) _MUSTCHECK;
 
index 55fdfff635276b0868aecd004ced92f47489bf4d..9fcbdbd794d001d56560492455ca109cce45aed5 100644 (file)
@@ -27,6 +27,7 @@ typedef enum {
        SBUF_EV_CONNECT_OK,     /* got connection */
        SBUF_EV_FLUSH,          /* data is sent, buffer empty */
        SBUF_EV_PKT_CALLBACK,   /* next part of pkt data */
+       SBUF_EV_TLS_READY       /* TLS was established */
 } SBufEvent;
 
 /*
@@ -39,6 +40,8 @@ typedef enum {
  */
 #define SBUF_SMALL_PKT 64
 
+struct tls;
+
 /* fwd def */
 typedef struct SBuf SBuf;
 typedef struct SBufIO SBufIO;
@@ -82,6 +85,8 @@ struct SBuf {
        IOBuf *io;              /* data buffer, lazily allocated */
 
        const SBufIO *ops;      /* normal vs. TLS */
+       struct tls *tls;        /* TLS context */
+       const char *tls_host;   /* target hostname */
 };
 
 #define sbuf_socket(sbuf) ((sbuf)->sock)
@@ -90,6 +95,10 @@ void sbuf_init(SBuf *sbuf, sbuf_cb_t proto_fn);
 bool sbuf_accept(SBuf *sbuf, int read_sock, bool is_unix)  _MUSTCHECK;
 bool sbuf_connect(SBuf *sbuf, const struct sockaddr *sa, int sa_len, int timeout_sec)  _MUSTCHECK;
 
+void sbuf_tls_setup(void);
+bool sbuf_tls_accept(SBuf *sbuf)  _MUSTCHECK;
+bool sbuf_tls_connect(SBuf *sbuf, const char *hostname)  _MUSTCHECK;
+
 bool sbuf_pause(SBuf *sbuf) _MUSTCHECK;
 void sbuf_continue(SBuf *sbuf);
 bool sbuf_close(SBuf *sbuf) _MUSTCHECK;
@@ -102,6 +111,7 @@ void sbuf_prepare_fetch(SBuf *sbuf, unsigned amount);
 bool sbuf_answer(SBuf *sbuf, const void *buf, unsigned len)  _MUSTCHECK;
 
 bool sbuf_continue_with_callback(SBuf *sbuf, sbuf_libevent_cb cb)  _MUSTCHECK;
+bool sbuf_use_callback_once(SBuf *sbuf, short ev, sbuf_libevent_cb user_cb) _MUSTCHECK;
 
 /*
  * Returns true if SBuf is has no data buffered
index a3417ba3d1934e12678e7789c1e63faa43c59e46..cb531ae6ca1fcaf656df96d04fc29dcdb6669556 100644 (file)
@@ -32,6 +32,8 @@
 #include <stdarg.h>
 #include <limits.h>
 
+#include <usual/tls/tls.h>
+
 #ifdef HAVE_CRYPT_H
 #include <crypt.h>
 #endif
diff --git a/lib b/lib
index 7dd946ae6023574eefdb9d254faae46805c16c07..7177b2af4f65037d19ff193073b06d6347d4b614 160000 (submodule)
--- a/lib
+++ b/lib
@@ -1 +1 @@
-Subproject commit 7dd946ae6023574eefdb9d254faae46805c16c07
+Subproject commit 7177b2af4f65037d19ff193073b06d6347d4b614
index ef6239a75fa4c8cd1d26f77a608533f22e4bf843..bc389db4c4d79384753958e275dc5db137be93eb 100644 (file)
@@ -270,7 +270,7 @@ static bool send_one_fd(PgSocket *admin,
        msg.msg_iovlen = 1;
 
        /* attach a fd */
-       if (pga_is_unix(&admin->remote_addr) && admin->own_user) {
+       if (pga_is_unix(&admin->remote_addr) && admin->own_user && !admin->sbuf.tls) {
                msg.msg_control = cntbuf;
                msg.msg_controllen = sizeof(cntbuf);
 
@@ -314,6 +314,10 @@ static bool show_one_fd(PgSocket *admin, PgSocket *sk)
        char addrbuf[PGADDR_BUF];
        const char *password = NULL;
 
+       /* Skip TLS sockets */
+       if (sk->sbuf.tls || (sk->link && sk->link->sbuf.tls))
+               return true;
+
        mbuf_init_fixed_reader(&tmp, sk->cancel_key, 8);
        if (!mbuf_get_uint64be(&tmp, &ckey))
                return false;
@@ -546,8 +550,8 @@ static bool admin_show_users(PgSocket *admin, const char *arg)
        return true;
 }
 
-#define SKF_STD "sssssisiTTssi"
-#define SKF_DBG "sssssisiTTssiiiiiiii"
+#define SKF_STD "sssssisiTTssis"
+#define SKF_DBG "sssssisiTTssisiiiiiii"
 
 static void socket_header(PktBuf *buf, bool debug)
 {
@@ -555,7 +559,8 @@ static void socket_header(PktBuf *buf, bool debug)
                                    "type", "user", "database", "state",
                                    "addr", "port", "local_addr", "local_port",
                                    "connect_time", "request_time",
-                                   "ptr", "link", "remote_pid",
+                                   "ptr", "link", "remote_pid", "tls",
+                                   /* debug follows */
                                    "recv_pos", "pkt_pos", "pkt_remain",
                                    "send_pos", "send_remain",
                                    "pkt_avail", "send_avail");
@@ -573,6 +578,7 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug)
        char ptrbuf[128], linkbuf[128];
        char l_addr[PGADDR_BUF], r_addr[PGADDR_BUF];
        IOBuf *io = sk->sbuf.io;
+       char infobuf[96] = "";
 
        if (io) {
                pkt_avail = iobuf_amount_parse(sk->sbuf.io);
@@ -597,6 +603,9 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug)
        if (is_server_socket(sk) && remote_pid == 0)
                remote_pid = be32dec(sk->cancel_key);
 
+       if (sk->sbuf.tls)
+               tls_get_connection_info(sk->sbuf.tls, infobuf, sizeof infobuf);
+
        pktbuf_write_DataRow(buf, debug ? SKF_DBG : SKF_STD,
                             is_server_socket(sk) ? "S" :"C",
                             sk->auth_user ? sk->auth_user->name : "(nouser)",
@@ -605,7 +614,8 @@ static void socket_row(PktBuf *buf, PgSocket *sk, const char *state, bool debug)
                             l_addr, pga_port(&sk->local_addr),
                             sk->connect_time,
                             sk->request_time,
-                            ptrbuf, linkbuf, remote_pid,
+                            ptrbuf, linkbuf, remote_pid, infobuf,
+                            /* debug */
                             io ? io->recv_pos : 0,
                             io ? io->parse_pos : 0,
                             sk->sbuf.pkt_remain,
index ec853749911441af20b2124c2e24053903aa97f8..6c243bdc7b281b5b9c95dc61869be70954991223 100644 (file)
@@ -141,8 +141,17 @@ static bool finish_set_pool(PgSocket *client, bool takeover)
                return false;
        }
 
-       if (cf_log_connections)
-               slog_info(client, "login attempt: db=%s user=%s", client->db->name, client->auth_user->name);
+       if (cf_log_connections) {
+               if (client->sbuf.tls) {
+                       char infobuf[96] = "";
+                       tls_get_connection_info(client->sbuf.tls, infobuf, sizeof infobuf);
+                       slog_info(client, "login attempt: db=%s user=%s tls=%s",
+                                 client->db->name, client->auth_user->name, infobuf);
+               } else {
+                       slog_info(client, "login attempt: db=%s user=%s tls=no",
+                                 client->db->name, client->auth_user->name);
+               }
+       }
 
        if (!check_fast_fail(client))
                return false;
@@ -433,9 +442,27 @@ static bool handle_client_startup(PgSocket *client, PktHdr *pkt)
        switch (pkt->type) {
        case PKT_SSLREQ:
                slog_noise(client, "C: req SSL");
-               slog_noise(client, "P: nak");
 
+#ifdef USE_TLS
+               if (client->sbuf.tls) {
+                       disconnect_client(client, false, "SSL req inside SSL");
+                       return false;
+               }
+               if (cf_client_tls_sslmode != SSLMODE_DISABLED) {
+                       slog_noise(client, "P: SSL ack");
+                       if (!sbuf_answer(&client->sbuf, "S", 1)) {
+                               disconnect_client(client, false, "failed to ack SSL");
+                               return false;
+                       }
+                       if (!sbuf_tls_accept(&client->sbuf)) {
+                               disconnect_client(client, false, "failed to accept SSL");
+                               return false;
+                       }
+                       break;
+               }
+#endif
                /* reject SSL attempt */
+               slog_noise(client, "P: nak");
                if (!sbuf_answer(&client->sbuf, "N", 1)) {
                        disconnect_client(client, false, "failed to nak SSL");
                        return false;
@@ -445,6 +472,12 @@ static bool handle_client_startup(PgSocket *client, PktHdr *pkt)
                disconnect_client(client, true, "Old V2 protocol not supported");
                return false;
        case PKT_STARTUP:
+               /* require SSL except on unix socket */
+               if (cf_client_tls_sslmode >= SSLMODE_REQUIRE && !client->sbuf.tls && !pga_is_unix(&client->remote_addr)) {
+                       disconnect_client(client, true, "SSL required");
+                       return false;
+               }
+
                if (client->pool && !client->wait_for_user_conn && !client->wait_for_user) {
                        disconnect_client(client, true, "client re-sent startup pkt");
                        return false;
@@ -633,6 +666,10 @@ bool client_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data)
        case SBUF_EV_PKT_CALLBACK:
                /* unused ATM */
                break;
+       case SBUF_EV_TLS_READY:
+               sbuf_continue(&client->sbuf);
+               res = true;
+               break;
        }
        return res;
 }
index f194eb1506f2544d068b58c54d12c4f42dc8431b..ef0344db94c19c715949a4d868d937d00b196707 100644 (file)
 #include <sys/resource.h>
 #endif
 
+#ifndef DEFAULT_TLS_CIPHERS
+/* enable only PFS, deprioritize/remove slower ones */
+#define DEFAULT_TLS_CIPHERS "EECDH+HIGH:EDH+HIGH:+AES256:+SHA256:+SHA384:+SSLv3:+EDH:-CAMELLIA:-3DES:!DSS:!aNULL"
+#endif
+
 static const char usage_str[] =
 "Usage: %s [OPTION]... config.ini\n"
 "  -d, --daemon           Run in background (as a daemon)\n"
@@ -137,6 +142,22 @@ int cf_log_disconnections;
 int cf_log_pooler_errors;
 int cf_application_name_add_host;
 
+int cf_client_tls_sslmode;
+char *cf_client_tls_protocols;
+char *cf_client_tls_ca_file;
+char *cf_client_tls_cert_file;
+char *cf_client_tls_key_file;
+char *cf_client_tls_ciphers;
+char *cf_client_tls_dheparams;
+char *cf_client_tls_ecdhecurve;
+
+int cf_server_tls_sslmode;
+char *cf_server_tls_protocols;
+char *cf_server_tls_ca_file;
+char *cf_server_tls_cert_file;
+char *cf_server_tls_key_file;
+char *cf_server_tls_ciphers;
+
 /*
  * config file description
  */
@@ -162,6 +183,18 @@ const struct CfLookup pool_mode_map[] = {
        { NULL }
 };
 
+const struct CfLookup sslmode_map[] = {
+       { "disabled", SSLMODE_DISABLED },
+#ifdef USE_TLS
+       { "allow", SSLMODE_ALLOW },
+       { "prefer", SSLMODE_PREFER },
+       { "require", SSLMODE_REQUIRE },
+       { "verify-ca", SSLMODE_VERIFY_CA },
+       { "verify-full", SSLMODE_VERIFY_FULL },
+#endif
+       { NULL }
+};
+
 static const struct CfKey bouncer_params [] = {
 CF_ABS("job_name", CF_STR, cf_jobname, CF_NO_RELOAD, "pgbouncer"),
 #ifdef WIN32
@@ -235,6 +268,23 @@ CF_ABS("log_connections", CF_INT, cf_log_connections, 0, "1"),
 CF_ABS("log_disconnections", CF_INT, cf_log_disconnections, 0, "1"),
 CF_ABS("log_pooler_errors", CF_INT, cf_log_pooler_errors, 0, "1"),
 CF_ABS("application_name_add_host", CF_INT, cf_application_name_add_host, 0, "0"),
+
+CF_ABS("client_tls_sslmode", CF_LOOKUP(sslmode_map), cf_client_tls_sslmode, CF_NO_RELOAD, "disabled"),
+CF_ABS("client_tls_ca_file", CF_STR, cf_client_tls_ca_file, CF_NO_RELOAD, ""),
+CF_ABS("client_tls_cert_file", CF_STR, cf_client_tls_cert_file, CF_NO_RELOAD, ""),
+CF_ABS("client_tls_key_file", CF_STR, cf_client_tls_key_file, CF_NO_RELOAD, ""),
+CF_ABS("client_tls_protocols", CF_STR, cf_client_tls_protocols, CF_NO_RELOAD, "all"),
+CF_ABS("client_tls_ciphers", CF_STR, cf_client_tls_ciphers, CF_NO_RELOAD, DEFAULT_TLS_CIPHERS),
+CF_ABS("client_tls_dheparams", CF_STR, cf_client_tls_dheparams, CF_NO_RELOAD, "auto"),
+CF_ABS("client_tls_ecdhcurve", CF_STR, cf_client_tls_ecdhecurve, CF_NO_RELOAD, "auto"),
+
+CF_ABS("server_tls_sslmode", CF_LOOKUP(sslmode_map), cf_server_tls_sslmode, CF_NO_RELOAD, "disabled"),
+CF_ABS("server_tls_ca_file", CF_STR, cf_server_tls_ca_file, CF_NO_RELOAD, ""),
+CF_ABS("server_tls_cert_file", CF_STR, cf_server_tls_cert_file, CF_NO_RELOAD, ""),
+CF_ABS("server_tls_key_file", CF_STR, cf_server_tls_key_file, CF_NO_RELOAD, ""),
+CF_ABS("server_tls_protocols", CF_STR, cf_server_tls_protocols, CF_NO_RELOAD, "all"),
+CF_ABS("server_tls_ciphers", CF_STR, cf_server_tls_ciphers, CF_NO_RELOAD, DEFAULT_TLS_CIPHERS),
+
 {NULL}
 };
 
@@ -736,6 +786,8 @@ int main(int argc, char *argv[])
        init_caches();
        logging_prefix_cb = log_socket_prefix;
 
+       sbuf_tls_setup();
+
        /* prefer cmdline over config for username */
        if (arg_username) {
                if (cf_username)
index acce1f64836508a894781fa272355d8494ea1cd8..8aad019dcc90ef77afb88df4c77d396793a9b56b 100644 (file)
@@ -369,6 +369,13 @@ bool send_startup_packet(PgSocket *server)
        return pktbuf_send_immediate(pkt, server);
 }
 
+bool send_sslreq_packet(PgSocket *server)
+{
+       int res;
+       SEND_wrap(16, pktbuf_write_SSLRequest, res, server);
+       return res;
+}
+
 int scan_text_result(struct MBuf *pkt, const char *tupdesc, ...)
 {
        const char *val = NULL;
index a00076fa1e71872358f164660f6d8f17cd752695..3c59c9b2ec31ff36ff8f0194e6cb64d2b7b59376 100644 (file)
@@ -40,6 +40,7 @@ enum WaitType {
        W_CONNECT,
        W_RECV,
        W_SEND,
+       W_ONCE
 };
 
 #define AssertSanity(sbuf) do { \
@@ -77,6 +78,20 @@ static const SBufIO raw_sbufio_ops = {
        raw_sbufio_close
 };
 
+/* I/O over TLS */
+#ifdef USE_TLS
+static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len);
+static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len);
+static int tls_sbufio_close(struct SBuf *sbuf);
+static const SBufIO tls_sbufio_ops = {
+       tls_sbufio_recv,
+       tls_sbufio_send,
+       tls_sbufio_close
+};
+static void sbuf_tls_accept_cb(int fd, short flags, void *_sbuf);
+static void sbuf_tls_connect_cb(int fd, short flags, void *_sbuf);
+#endif
+
 /*********************************
  * Public functions
  *********************************/
@@ -237,6 +252,31 @@ bool sbuf_continue_with_callback(SBuf *sbuf, sbuf_libevent_cb user_cb)
        return true;
 }
 
+bool sbuf_use_callback_once(SBuf *sbuf, short ev, sbuf_libevent_cb user_cb)
+{
+       int err;
+       AssertActive(sbuf);
+
+       if (sbuf->wait_type != W_NONE) {
+               err = event_del(&sbuf->ev);
+               sbuf->wait_type = W_NONE; /* make sure its called only once */
+               if (err < 0) {
+                       log_warning("sbuf_queue_once: event_del failed: %s", strerror(errno));
+                       return false;
+               }
+       }
+
+       /* setup one one-off event handler */
+       event_set(&sbuf->ev, sbuf->sock, ev, user_cb, sbuf);
+       err = event_add(&sbuf->ev, NULL);
+       if (err < 0) {
+               log_warning("sbuf_queue_once: event_add failed: %s", strerror(errno));
+               return false;
+       }
+       sbuf->wait_type = W_ONCE;
+       return true;
+}
+
 /* socket cleanup & close: keeps .handler and .arg values */
 bool sbuf_close(SBuf *sbuf)
 {
@@ -758,3 +798,285 @@ static int raw_sbufio_close(struct SBuf *sbuf)
        return 0;
 }
 
+/*
+ * TLS support.
+ */
+
+#ifdef USE_TLS
+
+static struct tls_config *client_accept_conf;
+static struct tls_config *server_connect_conf;
+static struct tls *client_accept_base;
+
+/*
+ * TLS setup
+ */
+
+static void setup_tls(struct tls_config *conf, const char *pfx, int sslmode,
+                     const char *protocols, const char *ciphers,
+                     const char *keyfile, const char *certfile, const char *cafile,
+                     const char *dheparams, const char *ecdhecurve,
+                     bool does_connect)
+{
+       int err;
+       if (*protocols) {
+               uint32_t protos = TLS_PROTOCOLS_ALL;
+               err = tls_config_parse_protocols(&protos, protocols);
+               if (err) {
+                       log_error("Invalid %s_protocols: %s", pfx, protocols);
+               } else {
+                       tls_config_set_protocols(conf, protos);
+               }
+       }
+       if (*ciphers) {
+               err = tls_config_set_ciphers(conf, ciphers);
+               if (err)
+                       log_error("Invalid %s_ciphers: %s", pfx, ciphers);
+       }
+       if (*dheparams) {
+               err = tls_config_set_dheparams(conf, dheparams);
+               if (err)
+                       log_error("Invalid %s_dheparams: %s", pfx, dheparams);
+       }
+       if (*ecdhecurve) {
+               err = tls_config_set_ecdhecurve(conf, ecdhecurve);
+               if (err)
+                       log_error("Invalid %s_ecdhecurve: %s", pfx, ecdhecurve);
+       }
+       if (*cafile) {
+               err = tls_config_set_ca_file(conf, cafile);
+               if (err)
+                       log_error("Invalid %s_ca_file: %s", pfx, cafile);
+       }
+       if (*keyfile) {
+               err = tls_config_set_key_file(conf, keyfile);
+               if (err)
+                       log_error("Invalid %s_key_file: %s", pfx, keyfile);
+       }
+       if (*certfile) {
+               err = tls_config_set_cert_file(conf, certfile);
+               if (err)
+                       log_error("Invalid %s_cert_file: %s", pfx, certfile);
+       }
+
+       if (sslmode == SSLMODE_VERIFY_FULL) {
+               tls_config_verify(conf);
+       } else if (sslmode == SSLMODE_VERIFY_CA) {
+               tls_config_insecure_noverifyname(conf);
+       } else {
+               tls_config_insecure_noverifycert(conf);
+               tls_config_insecure_noverifyname(conf);
+       }
+}
+
+void sbuf_tls_setup(void)
+{
+       int err;
+
+       if (cf_client_tls_sslmode != SSLMODE_DISABLED) {
+               if (!*cf_client_tls_key_file || !*cf_client_tls_cert_file)
+                       die("To allow TLS connections from clients, client_tls_key_file and client_tls_cert_file must be set.");
+       }
+       if (cf_auth_type == AUTH_CERT) {
+               if (cf_client_tls_sslmode != SSLMODE_VERIFY_FULL)
+                       die("auth_type=cert requires client_tls_sslmode=SSLMODE_VERIFY_FULL");
+               if (*cf_client_tls_ca_file == '\0')
+                       die("auth_type=cert requires client_tls_ca_file");
+       } else if (cf_client_tls_sslmode > SSLMODE_VERIFY_CA && *cf_client_tls_ca_file == '\0') {
+               die("client_tls_sslmode requires client_tls_ca_file");
+       }
+
+       err = tls_init();
+       if (err)
+               fatal("tls_init failed");
+
+       if (cf_server_tls_sslmode != SSLMODE_DISABLED) {
+               server_connect_conf = tls_config_new();
+               if (!server_connect_conf)
+                       die("tls_config_new failed 1");
+               setup_tls(server_connect_conf, "server_tls", cf_server_tls_sslmode,
+                         cf_server_tls_protocols, cf_server_tls_ciphers,
+                         cf_server_tls_key_file, cf_server_tls_cert_file,
+                         cf_server_tls_ca_file, "", "", true);
+       }
+
+       if (cf_client_tls_sslmode != SSLMODE_DISABLED) {
+               client_accept_conf = tls_config_new();
+               if (!client_accept_conf)
+                       die("tls_config_new failed 2");
+               setup_tls(client_accept_conf, "client_tls", cf_client_tls_sslmode,
+                         cf_client_tls_protocols, cf_client_tls_ciphers,
+                         cf_client_tls_key_file, cf_client_tls_cert_file,
+                         cf_client_tls_ca_file, cf_client_tls_dheparams,
+                         cf_client_tls_ecdhecurve, false);
+
+               client_accept_base = tls_server();
+               if (!client_accept_base)
+                       die("server_base failed");
+               err = tls_configure(client_accept_base, client_accept_conf);
+               if (err)
+                       die("TLS setup failed: %s", tls_error(client_accept_base));
+       }
+}
+
+/*
+ * Accept TLS connection.
+ */
+
+static bool handle_tls_accept(struct SBuf *sbuf)
+{
+       int err;
+
+       err = tls_accept_fds(client_accept_base, &sbuf->tls, sbuf->sock, sbuf->sock);
+       log_noise("tls_accept_fds: err=%d", err);
+       if (err == TLS_READ_AGAIN) {
+               return sbuf_use_callback_once(sbuf, EV_READ, sbuf_tls_accept_cb);
+       } else if (err == TLS_WRITE_AGAIN) {
+               return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_accept_cb);
+       } else if (err == 0) {
+               sbuf_call_proto(sbuf, SBUF_EV_TLS_READY);
+               return true;
+       } else {
+               log_warning("TLS accept error: %s", tls_error(sbuf->tls));
+               return false;
+       }
+}
+
+static void sbuf_tls_accept_cb(int fd, short flags, void *_sbuf)
+{
+       SBuf *sbuf = _sbuf;
+       sbuf->wait_type = W_NONE;
+       if (!handle_tls_accept(sbuf))
+               sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
+}
+
+bool sbuf_tls_accept(SBuf *sbuf)
+{
+       sbuf->ops = &tls_sbufio_ops;
+       return handle_tls_accept(sbuf);
+}
+
+/*
+ * Connect to remote TLS host.
+ */
+
+static bool handle_tls_connect(SBuf *sbuf)
+{
+       int err;
+
+       err = tls_connect_fds(sbuf->tls, sbuf->sock, sbuf->sock, sbuf->tls_host);
+       log_noise("tls_connect_fds: err=%d", err);
+       if (err == TLS_READ_AGAIN) {
+               return sbuf_use_callback_once(sbuf, EV_READ, sbuf_tls_connect_cb);
+       } else if (err == TLS_WRITE_AGAIN) {
+               return sbuf_use_callback_once(sbuf, EV_WRITE, sbuf_tls_connect_cb);
+       } else if (err == 0) {
+               sbuf_call_proto(sbuf, SBUF_EV_TLS_READY);
+               return true;
+       } else {
+               log_warning("TLS connect error: %s", tls_error(sbuf->tls));
+               return false;
+       }
+}
+
+static void sbuf_tls_connect_cb(int fd, short flags, void *_sbuf)
+{
+       SBuf *sbuf = _sbuf;
+       sbuf->wait_type = W_NONE;
+       if (!handle_tls_connect(sbuf))
+               sbuf_call_proto(sbuf, SBUF_EV_RECV_FAILED);
+}
+
+bool sbuf_tls_connect(SBuf *sbuf, const char *hostname)
+{
+       struct tls *ctls;
+       int err;
+
+       if (cf_server_tls_sslmode != SSLMODE_VERIFY_FULL)
+               hostname = NULL;
+
+       ctls = tls_client();
+       if (!ctls)
+               return false;
+       err = tls_configure(ctls, server_connect_conf);
+       if (err) {
+               log_error("tls client config failed: %s", tls_error(ctls));
+               tls_free(ctls);
+               return false;
+       }
+
+       sbuf->tls = ctls;
+       sbuf->tls_host = hostname;
+       sbuf->ops = &tls_sbufio_ops;
+
+       return handle_tls_connect(sbuf);
+}
+
+/*
+ * TLS IO ops.
+ */
+
+static int tls_sbufio_recv(struct SBuf *sbuf, void *dst, unsigned int len)
+{
+       int err;
+       size_t out = 0;
+
+       err = tls_read(sbuf->tls, dst, len, &out);
+       log_noise("tls_read: req=%u err=%d out=%d", len, err, (int)out);
+       if (!err) {
+               return out;
+       } else if (err == TLS_READ_AGAIN) {
+               errno = EAGAIN;
+       } else if (err == TLS_WRITE_AGAIN) {
+               log_warning("tls_sbufio_recv: got TLS_WRITE_AGAIN");
+               errno = EIO;
+       } else {
+               log_warning("tls_sbufio_recv: %s", tls_error(sbuf->tls));
+               errno = EIO;
+       }
+       return -1;
+}
+
+static int tls_sbufio_send(struct SBuf *sbuf, const void *data, unsigned int len)
+{
+       size_t out = 0;
+       int err;
+
+       err = tls_write(sbuf->tls, data, len, &out);
+       log_noise("tls_write: req=%u err=%d out=%d", len, err, (int)out);
+       if (!err) {
+               return out;
+       } else if (err == TLS_WRITE_AGAIN) {
+               errno = EAGAIN;
+       } else if (err == TLS_READ_AGAIN) {
+               log_warning("tls_sbufio_send: got TLS_READ_AGAIN");
+               errno = EIO;
+       } else {
+               log_warning("tls_sbufio_send: %s", tls_error(sbuf->tls));
+               errno = EIO;
+       }
+       return -1;
+}
+
+static int tls_sbufio_close(struct SBuf *sbuf)
+{
+       log_noise("tls_close");
+       if (sbuf->tls) {
+               tls_close(sbuf->tls);
+               tls_free(sbuf->tls);
+               sbuf->tls = NULL;
+       }
+       if (sbuf->sock > 0) {
+               safe_close(sbuf->sock);
+               sbuf->sock = 0;
+       }
+       return 0;
+}
+
+#else
+
+void sbuf_tls_setup(void) { }
+bool sbuf_tls_accept(SBuf *sbuf) { return false; }
+bool sbuf_tls_connect(SBuf *sbuf, const char *hostname) { return false; }
+
+#endif
index 1cdb2c30c127007419f0ba577c5f6ddfbafa19ad..cb782243b9858ea6476d3b96066245ddc8f9aab3 100644 (file)
@@ -369,13 +369,53 @@ static bool handle_connect(PgSocket *server)
                disconnect_server(server, false, "sent cancel req");
        } else {
                /* proceed with login */
-               res = send_startup_packet(server);
+               if (cf_server_tls_sslmode > SSLMODE_DISABLED) {
+                       slog_noise(server, "P: SSL request");
+                       res = send_sslreq_packet(server);
+                       if (res)
+                               server->wait_sslchar = true;
+               } else {
+                       slog_noise(server, "P: startup");
+                       res = send_startup_packet(server);
+               }
                if (!res)
                        disconnect_server(server, false, "startup pkt failed");
        }
        return res;
 }
 
+static bool handle_sslchar(PgSocket *server, struct MBuf *data)
+{
+       uint8_t schar = '?';
+       bool ok;
+
+       server->wait_sslchar = false;
+
+       ok = mbuf_get_byte(data, &schar);
+       if (!ok || (schar != 'S' && schar != 'N') || mbuf_avail_for_read(data) != 0) {
+               disconnect_server(server, false, "bad sslreq answer");
+               return false;
+       }
+
+       if (schar == 'S') {
+               slog_noise(server, "launching tls");
+               ok = sbuf_tls_connect(&server->sbuf, server->pool->db->host);
+       } else if (cf_server_tls_sslmode >= SSLMODE_REQUIRE) {
+               disconnect_server(server, false, "server refused SSL");
+               return false;
+       } else {
+               /* proceed with non-TLS connection */
+               ok = send_startup_packet(server);
+       }
+
+       if (ok) {
+               sbuf_prepare_skip(&server->sbuf, 1);
+       } else {
+               disconnect_server(server, false, "sslreq processing failed");
+       }
+       return ok;
+}
+
 /* callback from SBuf */
 bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data)
 {
@@ -383,6 +423,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data)
        PgSocket *server = container_of(sbuf, PgSocket, sbuf);
        PgPool *pool = server->pool;
        PktHdr pkt;
+       char infobuf[96];
 
        Assert(is_server_socket(server));
        Assert(server->state != SV_FREE);
@@ -399,6 +440,10 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data)
                disconnect_client(server->link, false, "unexpected eof");
                break;
        case SBUF_EV_READ:
+               if (server->wait_sslchar) {
+                       res = handle_sslchar(server, data);
+                       break;
+               }
                if (incomplete_header(data)) {
                        slog_noise(server, "S: got partial header, trying to wait a bit");
                        break;
@@ -468,6 +513,23 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data)
        case SBUF_EV_PKT_CALLBACK:
                slog_warning(server, "SBUF_EV_PKT_CALLBACK with state=%d", server->state);
                break;
+       case SBUF_EV_TLS_READY:
+               Assert(server->state == SV_LOGIN);
+
+               tls_get_connection_info(server->sbuf.tls, infobuf, sizeof infobuf);
+               if (cf_log_connections) {
+                       slog_info(server, "SSL established: %s", infobuf);
+               } else {
+                       slog_noise(server, "SSL established: %s", infobuf);
+               }
+
+               server->request_time = get_cached_time();
+               res = send_startup_packet(server);
+               if (res)
+                       sbuf_continue(&server->sbuf);
+               else
+                       disconnect_server(server, false, "TLS startup failed");
+               break;
        }
        if (!res && pool->db->admin)
                takeover_login_failed();