]> granicus.if.org Git - pgbouncer/commitdiff
big pkt code reorg.
authorMarko Kreen <markokr@gmail.com>
Wed, 15 Aug 2007 13:20:48 +0000 (13:20 +0000)
committerMarko Kreen <markokr@gmail.com>
Wed, 15 Aug 2007 13:20:48 +0000 (13:20 +0000)
in old code the 'pkt' contained whateved data was available
from SBuf, which made sanity checks complex.

this patch creates wrapper structure for passing current packet
info around.

src/admin.c
src/admin.h
src/bouncer.h
src/client.c
src/mbuf.h
src/proto.c
src/proto.h
src/server.c
src/takeover.c

index 2b41b53c6e698e991f74f59013d5ca6d342e259d..ce594ab257f956620e45455f9137f09b87e79d6a 100644 (file)
@@ -861,20 +861,20 @@ static bool admin_parse_query(PgSocket *admin, const char *q)
 }
 
 /* handle packets */
-bool admin_handle_client(PgSocket *admin, MBuf *pkt, int pkt_type, int pkt_len)
+bool admin_handle_client(PgSocket *admin, PktHdr *pkt)
 {
        const char *q;
        bool res;
 
        /* dont tolerate partial packets */
-       if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN) {
+       if (incomplete_pkt(pkt)) {
                disconnect_client(admin, true, "incomplete pkt");
                return false;
        }
 
-       switch (pkt_type) {
+       switch (pkt->type) {
        case 'Q':
-               q = mbuf_get_string(pkt);
+               q = mbuf_get_string(&pkt->data);
                if (!q) {
                        disconnect_client(admin, true, "incomplete query");
                        return false;
@@ -882,13 +882,13 @@ bool admin_handle_client(PgSocket *admin, MBuf *pkt, int pkt_type, int pkt_len)
                log_debug("got admin query: %s", q);
                res = admin_parse_query(admin, q);
                if (res)
-                       sbuf_prepare_skip(&admin->sbuf, pkt_len);
+                       sbuf_prepare_skip(&admin->sbuf, pkt->len);
                return res;
        case 'X':
                disconnect_client(admin, false, "close req");
                break;
        default:
-               admin_error(admin, "unsupported pkt type: %d", pkt_type);
+               admin_error(admin, "unsupported pkt type: %d", pkt_desc(pkt));
                disconnect_client(admin, true, "bad pkt");
                break;
        }
index 6163fa9bcd51a1aff55d81b70b2291d94192228a..42e810e8d840b34be403dea60ae16950cbd2e837 100644 (file)
@@ -15,7 +15,7 @@
  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
-bool admin_handle_client(PgSocket *client, MBuf *pkt, int pkt_type, int pkt_len);
+bool admin_handle_client(PgSocket *client, PktHdr *pkt);
 bool admin_pre_login(PgSocket *client);
 void admin_setup(void);
 bool admin_error(PgSocket *console, const char *fmt, ...);
index c85ffc05ca25933be6c6c4bc73820e71d1c2c416..f14053885414e7d16189d55078542a7b8d78a1e7 100644 (file)
@@ -64,6 +64,7 @@ typedef struct PgPool PgPool;
 typedef struct PgStats PgStats;
 typedef struct PgAddr PgAddr;
 typedef enum SocketState SocketState;
+typedef struct PktHdr PktHdr;
 
 #include "util.h"
 #include "list.h"
@@ -110,6 +111,8 @@ typedef enum SocketState SocketState;
 /* new style V3 packet header len - type:1b, len:4b */ 
 #define NEW_HEADER_LEN 5
 
+#define BACKENDKEY_LEN 8
+
 struct PgAddr {
        struct in_addr ip_addr;
        unsigned short port;
@@ -237,7 +240,7 @@ struct PgSocket {
        usec_t          query_start;    /* query start moment */
 
        char            salt[4];
-       uint8           cancel_key[8];
+       uint8           cancel_key[BACKENDKEY_LEN];
        PgUser *        auth_user;
        PgAddr          addr;
 
index 57ff1cec900ecb1603dd97e9b6b46c40ea379d0f..fd84c55847b336752700cda720ef02fcb5093f51 100644 (file)
@@ -94,16 +94,16 @@ bool set_pool(PgSocket *client, const char *dbname, const char *username)
        return true;
 }
 
-static bool decide_startup_pool(PgSocket *client, MBuf *pkt)
+static bool decide_startup_pool(PgSocket *client, PktHdr *pkt)
 {
        const char *username = NULL, *dbname = NULL;
        const char *key, *val;
 
        while (1) {
-               key = mbuf_get_string(pkt);
+               key = mbuf_get_string(&pkt->data);
                if (!key || *key == 0)
                        break;
-               val = mbuf_get_string(pkt);
+               val = mbuf_get_string(&pkt->data);
                if (!val)
                        break;
 
@@ -172,33 +172,28 @@ static bool send_client_authreq(PgSocket *client)
 }
 
 /* decide on packets of client in login phase */
-static bool handle_client_startup(PgSocket *client, MBuf *pkt)
+static bool handle_client_startup(PgSocket *client, PktHdr *pkt)
 {
-       unsigned pkt_type;
-       unsigned pkt_len;
        const char *passwd;
 
        SBuf *sbuf = &client->sbuf;
 
        /* don't tolerate partial packets */
-       if (!get_header(pkt, &pkt_type, &pkt_len)) {
-               disconnect_client(client, true, "client sent bad pkt header");
+       if (incomplete_pkt(pkt)) {
+               disconnect_client(client, true, "client sent partial pkt in startup phase");
                return false;
        }
 
        if (client->wait_for_welcome) {
                if  (finish_client_login(client)) {
                        /* the packet was already parsed */
-                       sbuf_prepare_skip(sbuf, pkt_len);
+                       sbuf_prepare_skip(sbuf, pkt->len);
                        return true;
                } else
                        return false;
        }
 
-       slog_noise(client, "pkt='%c' len=%d",
-                  pkt_type < 256 ? pkt_type : '?', pkt_len);
-
-       switch (pkt_type) {
+       switch (pkt->type) {
        case PKT_SSLREQ:
                slog_noise(client, "C: req SSL");
                slog_noise(client, "P: nak");
@@ -210,10 +205,6 @@ static bool handle_client_startup(PgSocket *client, MBuf *pkt)
                }
                break;
        case PKT_STARTUP:
-               if (mbuf_avail(pkt) < pkt_len - 8) {
-                       disconnect_client(client, true, "client sent partial pkt in startup");
-                       return false;
-               }
                if (client->pool) {
                        disconnect_client(client, true, "client re-sent startup pkt");
                        return false;
@@ -238,18 +229,13 @@ static bool handle_client_startup(PgSocket *client, MBuf *pkt)
                }
                break;
        case 'p':               /* PasswordMessage */
-               if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN) {
-                       disconnect_client(client, true, "client sent partial pkt in startup");
-                       return false;
-               }
-
                /* haven't requested it */
                if (cf_auth_type <= AUTH_TRUST) {
                        disconnect_client(client, true, "unrequested passwd pkt");
                        return false;
                }
 
-               passwd = mbuf_get_string(pkt);
+               passwd = mbuf_get_string(&pkt->data);
                if (passwd && check_client_passwd(client, passwd)) {
                        if (!finish_client_login(client))
                                return false;
@@ -259,9 +245,9 @@ static bool handle_client_startup(PgSocket *client, MBuf *pkt)
                }
                break;
        case PKT_CANCEL:
-               if (mbuf_avail(pkt) == 8) {
-                       const uint8 *key = mbuf_get_bytes(pkt, 8);
-                       memcpy(client->cancel_key, key, 8);
+               if (mbuf_avail(&pkt->data) == BACKENDKEY_LEN) {
+                       const uint8 *key = mbuf_get_bytes(&pkt->data, BACKENDKEY_LEN);
+                       memcpy(client->cancel_key, key, BACKENDKEY_LEN);
                        accept_cancel_request(client);
                } else
                        disconnect_client(client, false, "bad cancel request");
@@ -270,25 +256,17 @@ static bool handle_client_startup(PgSocket *client, MBuf *pkt)
                disconnect_client(client, false, "bad packet");
                return false;
        }
-       sbuf_prepare_skip(sbuf, pkt_len);
+       sbuf_prepare_skip(sbuf, pkt->len);
        client->request_time = get_cached_time();
        return true;
 }
 
 /* decide on packets of logged-in client */
-static bool handle_client_work(PgSocket *client, MBuf *pkt)
+static bool handle_client_work(PgSocket *client, PktHdr *pkt)
 {
-       unsigned pkt_type;
-       unsigned pkt_len;
        SBuf *sbuf = &client->sbuf;
 
-       if (!get_header(pkt, &pkt_type, &pkt_len)) {
-               disconnect_client(client, true, "bad packet header");
-               return false;
-       }
-       slog_noise(client, "pkt='%c' len=%d", pkt_type, pkt_len);
-
-       switch (pkt_type) {
+       switch (pkt->type) {
 
        /* request immidiate response from server */
        case 'H':               /* Flush */
@@ -320,24 +298,24 @@ static bool handle_client_work(PgSocket *client, MBuf *pkt)
                }
 
                if (client->pool->admin)
-                       return admin_handle_client(client, pkt, pkt_type, pkt_len);
+                       return admin_handle_client(client, pkt);
 
                /* aquire server */
                if (!find_server(client))
                        return false;
 
-               client->pool->stats.client_bytes += pkt_len;
+               client->pool->stats.client_bytes += pkt->len;
 
                /* tag the server as dirty */
                client->link->ready = 0;
 
                /* forward the packet */
-               sbuf_prepare_send(sbuf, &client->link->sbuf, pkt_len);
+               sbuf_prepare_send(sbuf, &client->link->sbuf, pkt->len);
                break;
 
        /* client wants to go away */
        default:
-               slog_error(client, "unknown pkt from client: %d/0x%x", pkt_type, pkt_type);
+               slog_error(client, "unknown pkt from client: %d/0x%x", pkt->type, pkt->type);
                disconnect_client(client, true, "unknown pkt");
                return false;
        case 'X': /* Terminate */
@@ -348,10 +326,12 @@ static bool handle_client_work(PgSocket *client, MBuf *pkt)
 }
 
 /* callback from SBuf */
-bool client_proto(SBuf *sbuf, SBufEvent evtype, MBuf *pkt, void *arg)
+bool client_proto(SBuf *sbuf, SBufEvent evtype, MBuf *data, void *arg)
 {
        bool res = false;
        PgSocket *client = arg;
+       PktHdr pkt;
+
 
        Assert(!is_server_socket(client));
        Assert(client->sbuf.sock);
@@ -374,21 +354,27 @@ bool client_proto(SBuf *sbuf, SBufEvent evtype, MBuf *pkt, void *arg)
                disconnect_server(client->link, false, "Server connection closed");
                break;
        case SBUF_EV_READ:
-               if (mbuf_avail(pkt) < NEW_HEADER_LEN) {
+               if (mbuf_avail(data) < NEW_HEADER_LEN && client->state != CL_LOGIN) {
                        slog_noise(client, "C: got partial header, trying to wait a bit");
                        return false;
                }
 
+               if (!get_header(data, &pkt)) {
+                       disconnect_client(client, true, "bad packet header");
+                       return false;
+               }
+               slog_noise(client, "pkt='%c' len=%d", pkt_desc(&pkt), pkt.len);
+
                client->request_time = get_cached_time();
                switch (client->state) {
                case CL_LOGIN:
-                       res = handle_client_startup(client, pkt);
+                       res = handle_client_startup(client, &pkt);
                        break;
                case CL_ACTIVE:
                        if (client->wait_for_welcome)
-                               res = handle_client_startup(client, pkt);
+                               res = handle_client_startup(client, &pkt);
                        else
-                               res = handle_client_work(client, pkt);
+                               res = handle_client_work(client, &pkt);
                        break;
                case CL_WAITING:
                        fatal("why waiting client in client_proto()");
index 687a9ef85848db266a2025b10daa0762d548e5f0..3142fea284afa361affd40e19b2eeba35599c3ee 100644 (file)
  */
 
 /*
- * Safe and easy access to fixed memory buffer
+ * Safe and easy access to fixed memory buffer.
+ */
+
+/*
+ * FIXME: the code should be converted so that
+ * the fatal()-s can be replaced by Asserts().
  */
 
 typedef struct MBuf MBuf;
@@ -81,12 +86,12 @@ static inline const uint8 * mbuf_get_bytes(MBuf *buf, unsigned len)
        return res;
 }
 
-static inline unsigned mbuf_avail(MBuf *buf)
+static inline unsigned mbuf_avail(const MBuf *buf)
 {
        return buf->end - buf->pos;
 }
 
-static inline unsigned mbuf_size(MBuf *buf)
+static inline unsigned mbuf_size(const MBuf *buf)
 {
        return buf->end - buf->data;
 }
@@ -101,3 +106,16 @@ static inline const char * mbuf_get_string(MBuf *buf)
        return res;
 }
 
+static inline void mbuf_copy(const MBuf *src, MBuf *dst)
+{
+       *dst = *src;
+}
+
+static inline void mbuf_slice(MBuf *src, unsigned len, MBuf *dst)
+{
+       if (len > mbuf_avail(src))
+               fatal("buffer overflow");
+       mbuf_init(dst, src->pos, len);
+       src->pos += len;
+}
+
index d47f3c73a006add44f7995b87a1146acfe6b2ccd..2ec210ac05ccc515e79ca926a152adc07fde26f4 100644 (file)
  */
 
 /* parses pkt header from buffer, returns false if failed */
-bool get_header(MBuf *pkt, unsigned *pkt_type_p, unsigned *pkt_len_p)
+bool get_header(MBuf *data, PktHdr *pkt)
 {
        unsigned type;
        unsigned len;
        unsigned code;
+       unsigned got;
+       unsigned avail;
+       MBuf hdr;
 
-       if (mbuf_avail(pkt) < NEW_HEADER_LEN) {
+       mbuf_copy(data, &hdr);
+
+       if (mbuf_avail(&hdr) < NEW_HEADER_LEN) {
                log_noise("get_header: less then 5 bytes available");
                return false;
        }
-       type = mbuf_get_char(pkt);
+       type = mbuf_get_char(&hdr);
        if (type != 0) {
                /* wire length does not include type byte */
-               len = mbuf_get_uint32(pkt) + 1;
+               len = mbuf_get_uint32(&hdr) + 1;
+               got = NEW_HEADER_LEN;
        } else {
-               if (mbuf_get_char(pkt) != 0) {
+               if (mbuf_get_char(&hdr) != 0) {
                        log_noise("get_header: unknown special pkt");
                        return false;
                }
                /* dont tolerate partial pkt */
-               if (mbuf_avail(pkt) < OLD_HEADER_LEN - 2) {
+               if (mbuf_avail(&hdr) < OLD_HEADER_LEN - 2) {
                        log_noise("get_header: less than 8 bytes for special pkt");
                        return false;
                }
-               len = mbuf_get_uint16(pkt);
-               code = mbuf_get_uint32(pkt);
+               len = mbuf_get_uint16(&hdr);
+               code = mbuf_get_uint32(&hdr);
                if (code == PKT_CANCEL)
                        type = PKT_CANCEL;
                else if (code == PKT_SSLREQ)
@@ -63,14 +69,27 @@ bool get_header(MBuf *pkt, unsigned *pkt_type_p, unsigned *pkt_len_p)
                        log_noise("get_header: unknown special pkt: len=%u code=%u", len, code);
                        return false;
                }
+               got = OLD_HEADER_LEN;
        }
 
        /* don't believe nonsense */
-       if (len < NEW_HEADER_LEN || len >= 0x80000000)
+       if (len < got || len >= 0x80000000)
                return false;
 
-       *pkt_type_p = type;
-       *pkt_len_p = len;
+       /* report pkt info */
+       pkt->type = type;
+       pkt->len = len;
+
+       /* fill apkt with only data for this pkt */
+       if (len > mbuf_avail(data))
+               avail = mbuf_avail(data);
+       else
+               avail = len;
+       mbuf_slice(data, avail, &pkt->data);
+
+       /* tag header as read */
+       mbuf_get_bytes(&pkt->data, got);
+
        return true;
 }
 
@@ -98,15 +117,15 @@ bool send_pooler_error(PgSocket *client, bool send_ready, const char *msg)
 /*
  * Parse server error message and log it.
  */
-void log_server_error(const char *note, MBuf *pkt)
+void log_server_error(const char *note, PktHdr *pkt)
 {
        const char *level = NULL, *msg = NULL, *val;
        int type;
-       while (mbuf_avail(pkt)) {
-               type = mbuf_get_char(pkt);
+       while (mbuf_avail(&pkt->data)) {
+               type = mbuf_get_char(&pkt->data);
                if (type == 0)
                        break;
-               val = mbuf_get_string(pkt);
+               val = mbuf_get_string(&pkt->data);
                if (!val)
                        break;
                if (type == 'S')
@@ -126,8 +145,7 @@ void log_server_error(const char *note, MBuf *pkt)
  */
 
 /* add another server parameter packet to cache */
-bool add_welcome_parameter(PgSocket *server,
-                          unsigned pkt_type, unsigned pkt_len, MBuf *pkt)
+bool add_welcome_parameter(PgSocket *server, PktHdr *pkt)
 {
        PgPool *pool = server->pool;
        PktBuf msg;
@@ -137,7 +155,7 @@ bool add_welcome_parameter(PgSocket *server,
                return true;
 
        /* incomplete startup msg from server? */
-       if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN)
+       if (incomplete_pkt(pkt))
                return false;
 
        pktbuf_static(&msg, pool->welcome_msg + pool->welcome_msg_len,
@@ -146,10 +164,10 @@ bool add_welcome_parameter(PgSocket *server,
        if (pool->welcome_msg_len == 0)
                pktbuf_write_AuthenticationOk(&msg);
 
-       key = mbuf_get_string(pkt);
-       val = mbuf_get_string(pkt);
+       key = mbuf_get_string(&pkt->data);
+       val = mbuf_get_string(&pkt->data);
        if (!key || !val) {
-               slog_error(server, "broken ParameterStatus packet");
+               disconnect_server(server, true, "broken ParameterStatus packet");
                return false;
        }
 
@@ -252,24 +270,17 @@ static bool login_md5_psw(PgSocket *server, const uint8 *salt)
 }
 
 /* answer server authentication request */
-bool answer_authreq(PgSocket *server,
-                   unsigned pkt_type, unsigned pkt_len,
-                   MBuf *pkt)
+bool answer_authreq(PgSocket *server, PktHdr *pkt)
 {
        unsigned cmd;
        const uint8 *salt;
        bool res = false;
-       unsigned pkt_remain;
 
        /* authreq body must contain 32bit cmd */
-       if (pkt_len < NEW_HEADER_LEN + 4)
-               return false;
-       /* is packet fully received? */
-       if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN)
+       if (mbuf_avail(&pkt->data) < 4)
                return false;
 
-       cmd = mbuf_get_uint32(pkt);
-       pkt_remain = pkt_len - NEW_HEADER_LEN - 4;
+       cmd = mbuf_get_uint32(&pkt->data);
        switch (cmd) {
        case 0:
                slog_debug(server, "S: auth ok");
@@ -280,17 +291,17 @@ bool answer_authreq(PgSocket *server,
                res = login_clear_psw(server);
                break;
        case 4:
-               if (pkt_remain < 2)
-                       return false;
                slog_debug(server, "S: req crypt psw");
-               salt = mbuf_get_bytes(pkt, 2);
+               if (mbuf_avail(&pkt->data) < 2)
+                       return false;
+               salt = mbuf_get_bytes(&pkt->data, 2);
                res = login_crypt_psw(server, salt);
                break;
        case 5:
-               if (pkt_remain < 4)
-                       return false;
                slog_debug(server, "S: req md5-crypted psw");
-               salt = mbuf_get_bytes(pkt, 4);
+               if (mbuf_avail(&pkt->data) < 4)
+                       return false;
+               salt = mbuf_get_bytes(&pkt->data, 4);
                res = login_md5_psw(server, salt);
                break;
        case 2: /* kerberos */
index eeaf343ef2f28ad1ddcd35f630d82b4655d7e40b..9f32d85f9f248507dd072e344f4d1a3242352a35 100644 (file)
  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
 
-bool get_header(MBuf *pkt, unsigned *pkt_type_p, unsigned *pkt_len_p);
+/*
+ * parsed packet header, plus whatever data is
+ * available in SBuf for this packet.
+ *
+ * if (pkt->len == mbuf_avail(&pkt->data))
+ *     packet is fully in buffer
+ *
+ * get_header() points pkt->data.pos after header.
+ * to packet body.
+ */
+struct PktHdr {
+       unsigned type;
+       unsigned len;
+       MBuf data;
+};
+
+bool get_header(MBuf *data, PktHdr *pkt);
 
 bool send_pooler_error(PgSocket *client, bool send_ready, const char *msg);
-void log_server_error(const char *note, MBuf *pkt);
+void log_server_error(const char *note, PktHdr *pkt);
 
-bool add_welcome_parameter(PgSocket *server, unsigned pkt_type, unsigned pkt_len, MBuf *pkt);
+bool add_welcome_parameter(PgSocket *server, PktHdr *pkt);
 void finish_welcome_msg(PgSocket *server);
 bool welcome_client(PgSocket *client);
 
-bool answer_authreq(PgSocket *server, unsigned pkt_type, unsigned pkt_len, MBuf *pkt);
+bool answer_authreq(PgSocket *server, PktHdr *pkt);
 
 bool send_startup_packet(PgSocket *server);
 
 int scan_text_result(MBuf *pkt, const char *tupdesc, ...);
 
+/* is packet completely in our buffer */
+static inline bool incomplete_pkt(const PktHdr *pkt)
+{
+       return mbuf_avail(&pkt->data) != pkt->len;
+}
+
+/* one char desc */
+static inline char pkt_desc(const PktHdr *pkt)
+{
+       return pkt->type > 256 ? '!' : pkt->type;
+}
+
index e6ad062bbac920bd9d6b668c5f2d4ef587bfe8a4..5ad3df7b385d463f92d658973009914bb85d2af2 100644 (file)
 
 #include "bouncer.h"
 
-static bool load_parameter(PgSocket *server, MBuf *pkt, unsigned pkt_len)
+static bool load_parameter(PgSocket *server, PktHdr *pkt)
 {
        const char *key, *val;
        PgSocket *client = server->link;
 
        /*
         * incomplete startup msg from server?
-        * (hdr is already parsed here)
         */
-       if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN)
+       if (incomplete_pkt(pkt))
                return false;
 
-       key = mbuf_get_string(pkt);
-       val = mbuf_get_string(pkt);
+       key = mbuf_get_string(&pkt->data);
+       val = mbuf_get_string(&pkt->data);
        if (!key || !val) {
                disconnect_server(server, true, "broken ParameterStatus packet");
                return false;
@@ -53,28 +52,20 @@ static bool load_parameter(PgSocket *server, MBuf *pkt, unsigned pkt_len)
 }
 
 /* process packets on server auth phase */
-static bool handle_server_startup(PgSocket *server, MBuf *pkt)
+static bool handle_server_startup(PgSocket *server, PktHdr *pkt)
 {
-       unsigned pkt_type;
-       unsigned pkt_len;
        SBuf *sbuf = &server->sbuf;
        bool res = false;
 
-       if (!get_header(pkt, &pkt_type, &pkt_len)) {
-               disconnect_server(server, true, "bad pkt in login phase");
-               return false;
-       }
-
-       if (mbuf_avail(pkt) < pkt_len - NEW_HEADER_LEN) {
+       if (incomplete_pkt(pkt)) {
                disconnect_server(server, true, "partial pkt in login phase");
                return false;
        }
 
-       slog_noise(server, "S: pkt '%c', len=%d", pkt_type, pkt_len);
 
-       switch (pkt_type) {
+       switch (pkt->type) {
        default:
-               slog_error(server, "unknown pkt from server: '%c'", pkt_type);
+               slog_error(server, "unknown pkt from server: '%c'", pkt_desc(pkt));
                disconnect_server(server, true, "unknown pkt from server");
                break;
 
@@ -86,13 +77,13 @@ static bool handle_server_startup(PgSocket *server, MBuf *pkt)
        /* packets that need closer look */
        case 'R':               /* AuthenticationXXX */
                slog_debug(server, "calling login_answer");
-               res = answer_authreq(server, pkt_type, pkt_len, pkt);
+               res = answer_authreq(server, pkt);
                if (!res)
                        disconnect_server(server, false, "failed to answer authreq");
                break;
 
        case 'S':               /* ParameterStatus */
-               res = add_welcome_parameter(server, pkt_type, pkt_len, pkt);
+               res = add_welcome_parameter(server, pkt);
                break;
 
        case 'Z':               /* ReadyForQuery */
@@ -112,28 +103,28 @@ static bool handle_server_startup(PgSocket *server, MBuf *pkt)
 
        /* ignorable packets */
        case 'K':               /* BackendKeyData */
-               if (mbuf_avail(pkt) >= 8)
-                       memcpy(server->cancel_key, mbuf_get_bytes(pkt, 8), 8);
+               if (mbuf_avail(&pkt->data) >= BACKENDKEY_LEN)
+                       memcpy(server->cancel_key,
+                              mbuf_get_bytes(&pkt->data, BACKENDKEY_LEN),
+                              BACKENDKEY_LEN);
                res = true;
                break;
 
        case 'N':               /* NoticeResponse */
-               slog_noise(server, "skipping pkt: %c", pkt_type);
+               slog_noise(server, "skipping pkt: %c", pkt_desc(pkt));
                res = true;
                break;
        }
 
        if (res)
-               sbuf_prepare_skip(sbuf, pkt_len);
+               sbuf_prepare_skip(sbuf, pkt->len);
 
        return res;
 }
 
 /* process packets on logged in connection */
-static bool handle_server_work(PgSocket *server, MBuf *pkt)
+static bool handle_server_work(PgSocket *server, PktHdr *pkt)
 {
-       unsigned pkt_type;
-       unsigned pkt_len;
        bool ready = 0;
        char state;
        SBuf *sbuf = &server->sbuf;
@@ -141,15 +132,9 @@ static bool handle_server_work(PgSocket *server, MBuf *pkt)
 
        Assert(!server->pool->admin);
 
-       if (!get_header(pkt, &pkt_type, &pkt_len)) {
-               disconnect_server(server, true, "bad pkt header");
-               return false;
-       }
-       slog_noise(server, "pkt='%c' len=%d", pkt_type, pkt_len);
-
-       switch (pkt_type) {
+       switch (pkt->type) {
        default:
-               slog_error(server, "unknown pkt: '%c'", pkt_type);
+               slog_error(server, "unknown pkt: '%c'", pkt_desc(pkt));
                disconnect_server(server, true, "unknown pkt");
                return false;
        
@@ -157,9 +142,9 @@ static bool handle_server_work(PgSocket *server, MBuf *pkt)
        case 'Z':               /* ReadyForQuery */
 
                /* if partial pkt, wait */
-               if (mbuf_avail(pkt) == 0)
+               if (mbuf_avail(&pkt->data) == 0)
                        return false;
-               state = mbuf_get_char(pkt);
+               state = mbuf_get_char(&pkt->data);
 
                /* set ready only if no tx */
                if (state == 'I')
@@ -172,7 +157,7 @@ static bool handle_server_work(PgSocket *server, MBuf *pkt)
                break;
 
        case 'S':               /* ParameterStatus */
-               if (!load_parameter(server, pkt, pkt_len))
+               if (!load_parameter(server, pkt))
                        return false;
                break;
 
@@ -228,17 +213,17 @@ static bool handle_server_work(PgSocket *server, MBuf *pkt)
                break;
        }
        server->ready = ready;
-       server->pool->stats.server_bytes += pkt_len;
+       server->pool->stats.server_bytes += pkt->len;
 
        if (server->setting_vars) {
                Assert(client);
-               sbuf_prepare_skip(sbuf, pkt_len);
+               sbuf_prepare_skip(sbuf, pkt->len);
                if (ready) {
                        server->setting_vars = 0;
                        sbuf_continue(&client->sbuf);
                }
        } else if (client) {
-               sbuf_prepare_send(sbuf, &client->sbuf, pkt_len);
+               sbuf_prepare_send(sbuf, &client->sbuf, pkt->len);
                if (ready) {
                        usec_t total;
                        Assert(client->query_start != 0);
@@ -252,8 +237,8 @@ static bool handle_server_work(PgSocket *server, MBuf *pkt)
                if (server->state != SV_TESTED)
                        slog_warning(server,
                                     "got packet '%c' from server when not linked",
-                                    pkt_type);
-               sbuf_prepare_skip(sbuf, pkt_len);
+                                    pkt_desc(pkt));
+               sbuf_prepare_skip(sbuf, pkt->len);
        }
 
        return true;
@@ -282,10 +267,11 @@ static bool handle_connect(PgSocket *server)
 }
 
 /* callback from SBuf */
-bool server_proto(SBuf *sbuf, SBufEvent evtype, MBuf *pkt, void *arg)
+bool server_proto(SBuf *sbuf, SBufEvent evtype, MBuf *data, void *arg)
 {
        bool res = false;
        PgSocket *server = arg;
+       PktHdr pkt;
 
        Assert(is_server_socket(server));
        Assert(server->state != SV_FREE);
@@ -304,21 +290,28 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, MBuf *pkt, void *arg)
                disconnect_client(server->link, false, "unexpected eof");
                break;
        case SBUF_EV_READ:
-               if (mbuf_avail(pkt) < NEW_HEADER_LEN) {
+               if (mbuf_avail(data) < NEW_HEADER_LEN) {
                        slog_noise(server, "S: got partial header, trying to wait a bit");
                        return false;
                }
 
+               /* parse pkt header */
+               if (!get_header(data, &pkt)) {
+                       disconnect_server(server, true, "bad pkt header");
+                       return false;
+               }
+               slog_noise(server, "S: pkt '%c', len=%d", pkt_desc(&pkt), pkt.len);
+
                server->request_time = get_cached_time();
                switch (server->state) {
                case SV_LOGIN:
-                       res = handle_server_startup(server, pkt);
+                       res = handle_server_startup(server, &pkt);
                        break;
                case SV_TESTED:
                case SV_USED:
                case SV_ACTIVE:
                case SV_IDLE:
-                       res = handle_server_work(server, pkt);
+                       res = handle_server_work(server, &pkt);
                        break;
                default:
                        fatal("server_proto: server in bad state: %d", server->state);
index 89221516ce71804715774631090da4adeaad02aa..551ecb617cca5febe57106996137edb2da3222e7 100644 (file)
@@ -182,28 +182,28 @@ static void takeover_parse_data(PgSocket *bouncer,
                                struct msghdr *msg, MBuf *data)
 {
        struct cmsghdr *cmsg;
-       unsigned pkt_type, pkt_len;
-       uint8 *pktptr;
-       MBuf pkt;
+       PktHdr pkt;
        
        cmsg = msg->msg_controllen ? CMSG_FIRSTHDR(msg) : NULL;
 
        while (mbuf_avail(data) > 0) {
-               if (!get_header(data, &pkt_type, &pkt_len))
+               if (!get_header(data, &pkt))
                        fatal("cannot parse packet");
 
-               /* crash on overflow is ok here */
-               pktptr = (uint8*)mbuf_get_bytes(data, pkt_len - NEW_HEADER_LEN);
-               mbuf_init(&pkt, pktptr, pkt_len - NEW_HEADER_LEN);
+               /*
+                * There should not be partial reads from UNIX socket.
+                */
+               if (incomplete_pkt(&pkt))
+                       fatal("unexpected partial packet");
 
-               switch (pkt_type) {
+               switch (pkt.type) {
                case 'T': /* RowDescription */
                        log_debug("takeover_parse_data: 'T'");
                        break;
                case 'D': /* DataRow */
                        log_debug("takeover_parse_data: 'D'");
                        if (cmsg) {
-                               takeover_load_fd(&pkt, cmsg);
+                               takeover_load_fd(&pkt.data, cmsg);
                                cmsg = CMSG_NXTHDR(msg, cmsg);
                        } else
                                fatal("got row without fd info");
@@ -213,13 +213,13 @@ static void takeover_parse_data(PgSocket *bouncer,
                        break;
                case 'C': /* CommandComplete */
                        log_debug("takeover_parse_data: 'C'");
-                       next_command(bouncer, &pkt);
+                       next_command(bouncer, &pkt.data);
                        break;
                case 'E': /* ErrorMessage */
                        log_server_error("old bouncer sent", &pkt);
                        fatal("something failed");
                default:
-                       fatal("takeover_parse_data: unexpected pkt: '%c'", pkt_type);
+                       fatal("takeover_parse_data: unexpected pkt: '%c'", pkt_desc(&pkt));
                }
        }
 }