#include "bouncer.h"
+#include <usual/pgutil.h>
+
static const char *hdr2hex(const struct MBuf *data, char *buf, unsigned buflen)
{
const uint8_t *bin = data->data + data->read_pos;
return false;
}
-bool set_pool(PgSocket *client, const char *dbname, const char *username)
+/* mask to get offset into valid_crypt_salt[] */
+#define SALT_MASK 0x3F
+
+static const char valid_crypt_salt[] =
+"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+static bool send_client_authreq(PgSocket *client)
+{
+ uint8_t saltlen = 0;
+ int res;
+ int auth = cf_auth_type;
+ uint8_t randbuf[2];
+
+ if (auth == AUTH_CRYPT) {
+ saltlen = 2;
+ get_random_bytes(randbuf, saltlen);
+ client->tmp_login_salt[0] = valid_crypt_salt[randbuf[0] & SALT_MASK];
+ client->tmp_login_salt[1] = valid_crypt_salt[randbuf[1] & SALT_MASK];
+ client->tmp_login_salt[2] = 0;
+ } else if (cf_auth_type == AUTH_MD5) {
+ saltlen = 4;
+ get_random_bytes((void*)client->tmp_login_salt, saltlen);
+ } else if (auth == AUTH_ANY)
+ auth = AUTH_TRUST;
+
+ SEND_generic(res, client, 'R', "ib", auth, client->tmp_login_salt, saltlen);
+ return res;
+}
+
+static void start_auth_request(PgSocket *client, const char *username)
+{
+ int res;
+ char quoted_username[64], query[128];
+
+ client->auth_user = client->db->auth_user;
+ /* have to fetch user info from db */
+ client->pool = get_pool(client->db, client->db->auth_user);
+ if (!find_server(client)) {
+ client->wait_for_user_conn = true;
+ return;
+ }
+ slog_noise(client, "Doing auth_conn query");
+ client->wait_for_user_conn = false;
+ client->wait_for_user = true;
+ if (!sbuf_pause(&client->sbuf)) {
+ release_server(client->link);
+ disconnect_client(client, true, "pause failed");
+ return;
+ }
+ client->link->ready = 0;
+
+ pg_quote_literal(quoted_username, username, sizeof(quoted_username));
+ snprintf(query, sizeof(query), "SELECT usename, passwd FROM pg_shadow WHERE usename=%s", quoted_username);
+ SEND_generic(res, client->link, 'Q', "s", query);
+ if (!res)
+ disconnect_server(client->link, false, "unable to send login query");
+}
+
+static bool finish_set_pool(PgSocket *client, bool takeover)
{
- PgDatabase *db;
- PgUser *user;
+ PgUser *user = client->auth_user;
+ /* pool user may be forced */
+ if (client->db->forced_user) {
+ user = client->db->forced_user;
+ }
+ client->pool = get_pool(client->db, user);
+ if (!client->pool) {
+ disconnect_client(client, true, "no memory for pool");
+ return false;
+ }
+
+ if (cf_log_connections)
+ slog_info(client, "login attempt: db=%s user=%s", client->db->name, client->auth_user->name);
+
+ if (!check_fast_fail(client))
+ return false;
+
+ if (takeover)
+ return true;
+
+ if (client->pool->db->admin) {
+ if (!admin_post_login(client))
+ return false;
+ }
+
+ if (cf_auth_type <= AUTH_TRUST || client->own_user) {
+ if (!finish_client_login(client))
+ return false;
+ } else {
+ if (!send_client_authreq(client)) {
+ disconnect_client(client, false, "failed to send auth req");
+ return false;
+ }
+ }
+ return true;
+}
+bool set_pool(PgSocket *client, const char *dbname, const char *username, bool takeover)
+{
/* find database */
- db = find_database(dbname);
- if (!db) {
- db = register_auto_database(dbname);
- if (!db) {
+ client->db = find_database(dbname);
+ if (!client->db) {
+ client->db = register_auto_database(dbname);
+ if (!client->db) {
disconnect_client(client, true, "No such database: %s", dbname);
+ if (cf_log_connections)
+ slog_info(client, "login failed: db=%s user=%s", dbname, username);
return false;
}
else {
}
/* are new connections allowed? */
- if (db->db_disabled) {
+ if (client->db->db_disabled) {
disconnect_client(client, true, "database does not allow connections: %s", dbname);
return false;
}
+ if (client->db->admin) {
+ if (admin_pre_login(client, username))
+ return finish_set_pool(client, takeover);
+ }
+
/* find user */
if (cf_auth_type == AUTH_ANY) {
/* ignore requested user */
- user = NULL;
-
- if (db->forced_user == NULL) {
+ if (client->db->forced_user == NULL) {
slog_error(client, "auth_type=any requires forced user");
disconnect_client(client, true, "bouncer config error");
return false;
}
- client->auth_user = db->forced_user;
+ client->auth_user = client->db->forced_user;
} else {
/* the user clients wants to log in as */
- user = find_user(username);
- if (!user) {
+ client->auth_user = find_user(username);
+ if (!client->auth_user && client->db->auth_user) {
+ if (takeover) {
+ client->auth_user = add_db_user(client->db, username, "");
+ return finish_set_pool(client, takeover);
+ }
+ start_auth_request(client, username);
+ return false;
+ }
+ if (!client->auth_user) {
disconnect_client(client, true, "No such user: %s", username);
+ if (cf_log_connections)
+ slog_info(client, "login failed: db=%s user=%s", dbname, username);
return false;
}
- client->auth_user = user;
}
+ return finish_set_pool(client, takeover);
+}
- /* pool user may be forced */
- if (db->forced_user)
- user = db->forced_user;
- client->pool = get_pool(db, user);
- if (!client->pool) {
- disconnect_client(client, true, "no memory for pool");
+bool handle_auth_response(PgSocket *client, PktHdr *pkt) {
+ uint16_t columns;
+ uint32_t length;
+ const char *username, *password;
+ PgUser user;
+
+ switch(pkt->type) {
+ case 'T': /* RowDescription */
+ if (!mbuf_get_uint16be(&pkt->data, &columns)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ if (columns != 2u) {
+ disconnect_server(client->link, false, "expected 1 column from login query, not %hu", columns);
+ return false;
+ }
+ break;
+ case 'D': /* DataRow */
+ memset(&user, 0, sizeof(user));
+ if (!mbuf_get_uint16be(&pkt->data, &columns)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ if (columns != 2u) {
+ disconnect_server(client->link, false, "expected 1 column from login query, not %hu", columns);
+ return false;
+ }
+ if (!mbuf_get_uint32be(&pkt->data, &length)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ if (!mbuf_get_chars(&pkt->data, length, &username)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ if (sizeof(user.name) - 1 < length)
+ length = sizeof(user.name) - 1;
+ memcpy(user.name, username, length);
+ if (!mbuf_get_uint32be(&pkt->data, &length)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ if (length == (uint32_t)-1) {
+ // NULL - set an md5 password with an impossible value,
+ // so that nothing will ever match
+ password = "md5";
+ length = 3;
+ } else {
+ if (!mbuf_get_chars(&pkt->data, length, &password)) {
+ disconnect_server(client->link, false, "bad packet");
+ return false;
+ }
+ }
+ if (sizeof(user.passwd) - 1 < length)
+ length = sizeof(user.passwd) - 1;
+ memcpy(user.passwd, password, length);
+
+ client->auth_user = add_db_user(client->db, user.name, user.passwd);
+ if (!client->auth_user) {
+ disconnect_server(client->link, false, "unable to allocate new user for auth");
+ return false;
+ }
+ break;
+ case 'C': /* CommandComplete */
+ break;
+ case 'Z': /* ReadyForQuery */
+ sbuf_prepare_skip(&client->link->sbuf, pkt->len);
+ if (!client->auth_user) {
+ if (cf_log_connections)
+ slog_info(client, "login failed: db=%s", client->db->name);
+ disconnect_client(client, true, "No such user");
+ } else {
+ slog_noise(client, "auth query complete");
+ sbuf_continue(&client->sbuf);
+ }
+ return true;
+ default:
+ disconnect_server(client->link, false, "unexpected response from login query");
return false;
}
-
- return check_fast_fail(client);
+ sbuf_prepare_skip(&client->link->sbuf, pkt->len);
+ return true;
}
static bool decide_startup_pool(PgSocket *client, PktHdr *pkt)
}
}
- /* find pool and log about it */
- if (set_pool(client, dbname, username)) {
- if (cf_log_connections)
- slog_info(client, "login attempt: db=%s user=%s", dbname, username);
- return true;
- } else {
- if (cf_log_connections)
- slog_info(client, "login failed: db=%s user=%s", dbname, username);
- return false;
- }
-}
-
-/* mask to get offset into valid_crypt_salt[] */
-#define SALT_MASK 0x3F
-
-static const char valid_crypt_salt[] =
-"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
-
-static bool send_client_authreq(PgSocket *client)
-{
- uint8_t saltlen = 0;
- int res;
- int auth = cf_auth_type;
- uint8_t randbuf[2];
-
- if (auth == AUTH_CRYPT) {
- saltlen = 2;
- get_random_bytes(randbuf, saltlen);
- client->tmp_login_salt[0] = valid_crypt_salt[randbuf[0] & SALT_MASK];
- client->tmp_login_salt[1] = valid_crypt_salt[randbuf[1] & SALT_MASK];
- client->tmp_login_salt[2] = 0;
- } else if (cf_auth_type == AUTH_MD5) {
- saltlen = 4;
- get_random_bytes((void*)client->tmp_login_salt, saltlen);
- } else if (auth == AUTH_ANY)
- auth = AUTH_TRUST;
-
- SEND_generic(res, client, 'R', "ib", auth, client->tmp_login_salt, saltlen);
- return res;
+ /* find pool */
+ return set_pool(client, dbname, username, false);
}
/* decide on packets of client in login phase */
disconnect_client(client, true, "Old V2 protocol not supported");
return false;
case PKT_STARTUP:
- if (client->pool) {
+ if (client->pool && !client->wait_for_user_conn && !client->wait_for_user) {
disconnect_client(client, true, "client re-sent startup pkt");
return false;
}
- if (!decide_startup_pool(client, pkt))
- return false;
-
- if (client->pool->db->admin) {
- if (!admin_pre_login(client))
+ if (client->wait_for_user) {
+ client->wait_for_user = false;
+ if (!finish_set_pool(client, false))
return false;
+ } else if (!decide_startup_pool(client, pkt)) {
+ return false;
}
- if (cf_auth_type <= AUTH_TRUST || client->own_user) {
- if (!finish_client_login(client))
- return false;
- } else {
- if (!send_client_authreq(client)) {
- disconnect_client(client, false, "failed to send auth req");
- return false;
- }
- }
break;
case 'p': /* PasswordMessage */
/* haven't requested it */
return strcmp(name, user->name);
}
+/* destroy PgUser, for usage with btree */
+static void user_node_release(struct AANode *node, void *arg)
+{
+ PgUser *user = container_of(node, PgUser, tree_node);
+ slab_free(user_cache, user);
+}
+
/* initialization before config loading */
void init_objects(void)
{
statlist_remove(&justfree_client_list, &client->head);
break;
case CL_LOGIN:
+ if (newstate == CL_WAITING)
+ newstate = CL_WAITING_LOGIN;
statlist_remove(&login_client_list, &client->head);
break;
+ case CL_WAITING_LOGIN:
+ if (newstate == CL_ACTIVE)
+ newstate = CL_LOGIN;
case CL_WAITING:
statlist_remove(&pool->waiting_client_list, &client->head);
break;
statlist_append(&login_client_list, &client->head);
break;
case CL_WAITING:
+ case CL_WAITING_LOGIN:
statlist_append(&pool->waiting_client_list, &client->head);
break;
case CL_ACTIVE:
slab_free(db_cache, db);
return NULL;
}
+ aatree_init(&db->user_tree, user_node_cmp, user_node_release);
put_in_order(&db->head, &database_list, cmp_database);
}
return user;
}
+/* add or update db users */
+PgUser *add_db_user(PgDatabase *db, const char *name, const char *passwd)
+{
+ PgUser *user = NULL;
+ struct AANode *node;
+
+ node = aatree_search(&db->user_tree, (uintptr_t)name);
+ user = node ? container_of(node, PgUser, tree_node) : NULL;
+
+ if (user == NULL) {
+ user = slab_alloc(user_cache);
+ if (!user)
+ return NULL;
+
+ list_init(&user->head);
+ list_init(&user->pool_list);
+ safe_strcpy(user->name, name, sizeof(user->name));
+
+ aatree_insert(&db->user_tree, (uintptr_t)user->name, &user->tree_node);
+ user->pool_mode = POOL_INHERIT;
+ }
+ safe_strcpy(user->passwd, passwd, sizeof(user->passwd));
+ return user;
+}
+
/* create separate user object for storing server user info */
PgUser *force_user(PgDatabase *db, const char *name, const char *passwd)
{
/* deactivate socket and put into wait queue */
static void pause_client(PgSocket *client)
{
- Assert(client->state == CL_ACTIVE);
+ Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN);
slog_debug(client, "pause_client");
change_client_state(client, CL_WAITING);
/* wake client from wait */
void activate_client(PgSocket *client)
{
- Assert(client->state == CL_WAITING);
+ Assert(client->state == CL_WAITING || client->state == CL_WAITING_LOGIN);
slog_debug(client, "activate_client");
change_client_state(client, CL_ACTIVE);
bool res;
bool varchange = false;
- Assert(client->state == CL_ACTIVE);
+ Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN);
if (client->link)
return true;
}
case CL_LOGIN:
case CL_WAITING:
+ case CL_WAITING_LOGIN:
case CL_CANCEL:
break;
default:
return false;
client->suspended = 1;
- if (!set_pool(client, dbname, username))
+ if (!set_pool(client, dbname, username, true))
return false;
change_client_state(client, CL_ACTIVE);