From: Heikki Linnakangas Date: Fri, 28 Apr 2017 12:04:02 +0000 (+0300) Subject: Misc SCRAM code cleanups. X-Git-Tag: REL_10_BETA1~143 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=d981074c24d2f1e4f44bc6d80e967e523ce64f50;p=postgresql Misc SCRAM code cleanups. * Move computation of SaltedPassword to a separate function from scram_ClientOrServerKey(). This saves a lot of cycles in libpq, by computing SaltedPassword only once per authentication. (Computing SaltedPassword is expensive by design.) * Split scram_ClientOrServerKey() into two functions. Improves readability, by making the calling code less verbose. * Rename "server proof" to "server signature", to better match the nomenclature used in RFC 5802. * Rename SCRAM_SALT_LEN to SCRAM_DEFAULT_SALT_LEN, to make it more clear that the salt can be of any length, and the constant only specifies how long a salt we use when we generate a new verifier. Also rename SCRAM_ITERATIONS_DEFAULT to SCRAM_DEFAULT_ITERATIONS, for consistency. These things caught my eye while working on other upcoming changes. --- diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 16bea446e3..5c85af943c 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -396,7 +396,8 @@ scram_build_verifier(const char *username, const char *password, { char *prep_password = NULL; pg_saslprep_rc rc; - char saltbuf[SCRAM_SALT_LEN]; + char saltbuf[SCRAM_DEFAULT_SALT_LEN]; + uint8 salted_password[SCRAM_KEY_LEN]; uint8 keybuf[SCRAM_KEY_LEN]; char *encoded_salt; char *encoded_storedkey; @@ -414,10 +415,10 @@ scram_build_verifier(const char *username, const char *password, password = (const char *) prep_password; if (iterations <= 0) - iterations = SCRAM_ITERATIONS_DEFAULT; + iterations = SCRAM_DEFAULT_ITERATIONS; /* Generate salt, and encode it in base64 */ - if (!pg_backend_random(saltbuf, SCRAM_SALT_LEN)) + if (!pg_backend_random(saltbuf, SCRAM_DEFAULT_SALT_LEN)) { ereport(LOG, (errcode(ERRCODE_INTERNAL_ERROR), @@ -425,13 +426,14 @@ scram_build_verifier(const char *username, const char *password, return NULL; } - encoded_salt = palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1); - encoded_len = pg_b64_encode(saltbuf, SCRAM_SALT_LEN, encoded_salt); + encoded_salt = palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1); + encoded_len = pg_b64_encode(saltbuf, SCRAM_DEFAULT_SALT_LEN, encoded_salt); encoded_salt[encoded_len] = '\0'; /* Calculate StoredKey, and encode it in base64 */ - scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, - iterations, SCRAM_CLIENT_KEY_NAME, keybuf); + scram_SaltedPassword(password, saltbuf, SCRAM_DEFAULT_SALT_LEN, + iterations, salted_password); + scram_ClientKey(salted_password, keybuf); scram_H(keybuf, SCRAM_KEY_LEN, keybuf); /* StoredKey */ encoded_storedkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); @@ -440,8 +442,7 @@ scram_build_verifier(const char *username, const char *password, encoded_storedkey[encoded_len] = '\0'; /* And same for ServerKey */ - scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, iterations, - SCRAM_SERVER_KEY_NAME, keybuf); + scram_ServerKey(salted_password, keybuf); encoded_serverkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); encoded_len = pg_b64_encode((const char *) keybuf, SCRAM_KEY_LEN, @@ -473,6 +474,7 @@ scram_verify_plain_password(const char *username, const char *password, char *salt; int saltlen; int iterations; + uint8 salted_password[SCRAM_KEY_LEN]; uint8 stored_key[SCRAM_KEY_LEN]; uint8 server_key[SCRAM_KEY_LEN]; uint8 computed_key[SCRAM_KEY_LEN]; @@ -502,9 +504,9 @@ scram_verify_plain_password(const char *username, const char *password, if (rc == SASLPREP_SUCCESS) password = prep_password; - /* Compute Server key based on the user-supplied plaintext password */ - scram_ClientOrServerKey(password, salt, saltlen, iterations, - SCRAM_SERVER_KEY_NAME, computed_key); + /* Compute Server Key based on the user-supplied plaintext password */ + scram_SaltedPassword(password, salt, saltlen, iterations, salted_password); + scram_ServerKey(salted_password, computed_key); if (prep_password) pfree(prep_password); @@ -630,12 +632,12 @@ mock_scram_verifier(const char *username, int *iterations, char **salt, /* Generate deterministic salt */ raw_salt = scram_MockSalt(username); - encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1); - encoded_len = pg_b64_encode(raw_salt, SCRAM_SALT_LEN, encoded_salt); + encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1); + encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt); encoded_salt[encoded_len] = '\0'; *salt = encoded_salt; - *iterations = SCRAM_ITERATIONS_DEFAULT; + *iterations = SCRAM_DEFAULT_ITERATIONS; /* StoredKey and ServerKey are not used in a doomed authentication */ memset(stored_key, 0, SCRAM_KEY_LEN); @@ -1179,7 +1181,7 @@ build_server_final_message(scram_state *state) /* * Determinisitcally generate salt for mock authentication, using a SHA256 * hash based on the username and a cluster-level secret key. Returns a - * pointer to a static buffer of size SCRAM_SALT_LEN. + * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN. */ static char * scram_MockSalt(const char *username) @@ -1194,7 +1196,7 @@ scram_MockSalt(const char *username) * not larger the SHA256 digest length. If the salt is smaller, the caller * will just ignore the extra data)) */ - StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_SALT_LEN, + StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN, "salt length greater than SHA256 digest length"); pg_sha256_init(&ctx); diff --git a/src/common/scram-common.c b/src/common/scram-common.c index df9f0eaa90..a8ea44944c 100644 --- a/src/common/scram-common.c +++ b/src/common/scram-common.c @@ -98,14 +98,16 @@ scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx) } /* - * Iterate hash calculation of HMAC entry using given salt. - * scram_Hi() is essentially PBKDF2 (see RFC2898) with HMAC() as the - * pseudorandom function. + * Calculate SaltedPassword. + * + * The password should already be normalized by SASLprep. */ -static void -scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *result) +void +scram_SaltedPassword(const char *password, + const char *salt, int saltlen, int iterations, + uint8 *result) { - int str_len = strlen(str); + int password_len = strlen(password); uint32 one = htonl(1); int i, j; @@ -113,8 +115,14 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 * uint8 Ui_prev[SCRAM_KEY_LEN]; scram_HMAC_ctx hmac_ctx; + /* + * Iterate hash calculation of HMAC entry using given salt. This is + * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom + * function. + */ + /* First iteration */ - scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len); + scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len); scram_HMAC_update(&hmac_ctx, salt, saltlen); scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32)); scram_HMAC_final(Ui_prev, &hmac_ctx); @@ -123,7 +131,7 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 * /* Subsequent iterations */ for (i = 2; i <= iterations; i++) { - scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len); + scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len); scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN); scram_HMAC_final(Ui, &hmac_ctx); for (j = 0; j < SCRAM_KEY_LEN; j++) @@ -148,20 +156,27 @@ scram_H(const uint8 *input, int len, uint8 *result) } /* - * Calculate ClientKey or ServerKey. - * - * The password should already be normalized by SASLprep. + * Calculate ClientKey. */ void -scram_ClientOrServerKey(const char *password, - const char *salt, int saltlen, int iterations, - const char *keystr, uint8 *result) +scram_ClientKey(const uint8 *salted_password, uint8 *result) +{ + scram_HMAC_ctx ctx; + + scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN); + scram_HMAC_update(&ctx, "Client Key", strlen("Client Key")); + scram_HMAC_final(result, &ctx); +} + +/* + * Calculate ServerKey. + */ +void +scram_ServerKey(const uint8 *salted_password, uint8 *result) { - uint8 keybuf[SCRAM_KEY_LEN]; scram_HMAC_ctx ctx; - scram_Hi(password, salt, saltlen, iterations, keybuf); - scram_HMAC_init(&ctx, keybuf, SCRAM_KEY_LEN); - scram_HMAC_update(&ctx, keystr, strlen(keystr)); + scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN); + scram_HMAC_update(&ctx, "Server Key", strlen("Server Key")); scram_HMAC_final(result, &ctx); } diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h index 6740069eee..656d9e1e6b 100644 --- a/src/include/common/scram-common.h +++ b/src/include/common/scram-common.h @@ -29,14 +29,10 @@ #define SCRAM_RAW_NONCE_LEN 10 /* length of salt when generating new verifiers */ -#define SCRAM_SALT_LEN 10 +#define SCRAM_DEFAULT_SALT_LEN 10 /* default number of iterations when generating verifier */ -#define SCRAM_ITERATIONS_DEFAULT 4096 - -/* Base name of keys used for proof generation */ -#define SCRAM_SERVER_KEY_NAME "Server Key" -#define SCRAM_CLIENT_KEY_NAME "Client Key" +#define SCRAM_DEFAULT_ITERATIONS 4096 /* * Context data for HMAC used in SCRAM authentication. @@ -51,9 +47,10 @@ extern void scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen); extern void scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen); extern void scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx); +extern void scram_SaltedPassword(const char *password, const char *salt, + int saltlen, int iterations, uint8 *result); extern void scram_H(const uint8 *str, int len, uint8 *result); -extern void scram_ClientOrServerKey(const char *password, const char *salt, - int saltlen, int iterations, - const char *keystr, uint8 *result); +extern void scram_ClientKey(const uint8 *salted_password, uint8 *result); +extern void scram_ServerKey(const uint8 *salted_password, uint8 *result); #endif /* SCRAM_COMMON_H */ diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c index c56e91e0e0..be271ce8ac 100644 --- a/src/interfaces/libpq/fe-auth-scram.c +++ b/src/interfaces/libpq/fe-auth-scram.c @@ -46,6 +46,7 @@ typedef struct char *password; /* We construct these */ + uint8 SaltedPassword[SCRAM_KEY_LEN]; char *client_nonce; char *client_first_message_bare; char *client_final_message_without_proof; @@ -59,7 +60,7 @@ typedef struct /* These come from the server-final message */ char *server_final_message; - char ServerProof[SCRAM_KEY_LEN]; + char ServerSignature[SCRAM_KEY_LEN]; } fe_scram_state; static bool read_server_first_message(fe_scram_state *state, char *input, @@ -70,7 +71,7 @@ static char *build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage); static char *build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage); -static bool verify_server_proof(fe_scram_state *state); +static bool verify_server_signature(fe_scram_state *state); static void calculate_client_proof(fe_scram_state *state, const char *client_final_message_without_proof, uint8 *result); @@ -216,12 +217,12 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, goto error; /* - * Verify server proof, to make sure we're talking to the genuine - * server. XXX: A fake server could simply not require + * Verify server signature, to make sure we're talking to the + * genuine server. XXX: A fake server could simply not require * authentication, though. There is currently no option in libpq * to reject a connection, if SCRAM authentication did not happen. */ - if (verify_server_proof(state)) + if (verify_server_signature(state)) *success = true; else { @@ -486,12 +487,11 @@ read_server_first_message(fe_scram_state *state, char *input, * Read the final exchange message coming from the server. */ static bool -read_server_final_message(fe_scram_state *state, - char *input, +read_server_final_message(fe_scram_state *state, char *input, PQExpBuffer errormessage) { - char *encoded_server_proof; - int server_proof_len; + char *encoded_server_signature; + int server_signature_len; state->server_final_message = strdup(input); if (!state->server_final_message) @@ -513,8 +513,8 @@ read_server_final_message(fe_scram_state *state, } /* Parse the message. */ - encoded_server_proof = read_attr_value(&input, 'v', errormessage); - if (encoded_server_proof == NULL) + encoded_server_signature = read_attr_value(&input, 'v', errormessage); + if (encoded_server_signature == NULL) { /* read_attr_value() has generated an error message */ return false; @@ -524,13 +524,13 @@ read_server_final_message(fe_scram_state *state, printfPQExpBuffer(errormessage, libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n")); - server_proof_len = pg_b64_decode(encoded_server_proof, - strlen(encoded_server_proof), - state->ServerProof); - if (server_proof_len != SCRAM_KEY_LEN) + server_signature_len = pg_b64_decode(encoded_server_signature, + strlen(encoded_server_signature), + state->ServerSignature); + if (server_signature_len != SCRAM_KEY_LEN) { printfPQExpBuffer(errormessage, - libpq_gettext("malformed SCRAM message (invalid server proof)\n")); + libpq_gettext("malformed SCRAM message (invalid server signature)\n")); return false; } @@ -552,8 +552,14 @@ calculate_client_proof(fe_scram_state *state, int i; scram_HMAC_ctx ctx; - scram_ClientOrServerKey(state->password, state->salt, state->saltlen, - state->iterations, SCRAM_CLIENT_KEY_NAME, ClientKey); + /* + * Calculate SaltedPassword, and store it in 'state' so that we can reuse + * it later in verify_server_signature. + */ + scram_SaltedPassword(state->password, state->salt, state->saltlen, + state->iterations, state->SaltedPassword); + + scram_ClientKey(state->SaltedPassword, ClientKey); scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey); scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN); @@ -575,19 +581,17 @@ calculate_client_proof(fe_scram_state *state, } /* - * Validate the server proof, received as part of the final exchange message - * received from the server. + * Validate the server signature, received as part of the final exchange + * message received from the server. */ static bool -verify_server_proof(fe_scram_state *state) +verify_server_signature(fe_scram_state *state) { - uint8 ServerSignature[SCRAM_KEY_LEN]; + uint8 expected_ServerSignature[SCRAM_KEY_LEN]; uint8 ServerKey[SCRAM_KEY_LEN]; scram_HMAC_ctx ctx; - scram_ClientOrServerKey(state->password, state->salt, state->saltlen, - state->iterations, SCRAM_SERVER_KEY_NAME, - ServerKey); + scram_ServerKey(state->SaltedPassword, ServerKey); /* calculate ServerSignature */ scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN); @@ -602,9 +606,9 @@ verify_server_proof(fe_scram_state *state) scram_HMAC_update(&ctx, state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)); - scram_HMAC_final(ServerSignature, &ctx); + scram_HMAC_final(expected_ServerSignature, &ctx); - if (memcmp(ServerSignature, state->ServerProof, SCRAM_KEY_LEN) != 0) + if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0) return false; return true;