]> granicus.if.org Git - postgresql/commitdiff
Misc SCRAM code cleanups.
authorHeikki Linnakangas <heikki.linnakangas@iki.fi>
Fri, 28 Apr 2017 12:04:02 +0000 (15:04 +0300)
committerHeikki Linnakangas <heikki.linnakangas@iki.fi>
Fri, 28 Apr 2017 12:22:38 +0000 (15:22 +0300)
* Move computation of SaltedPassword to a separate function from
  scram_ClientOrServerKey(). This saves a lot of cycles in libpq, by
  computing SaltedPassword only once per authentication. (Computing
  SaltedPassword is expensive by design.)

* Split scram_ClientOrServerKey() into two functions. Improves
  readability, by making the calling code less verbose.

* Rename "server proof" to "server signature", to better match the
  nomenclature used in RFC 5802.

* Rename SCRAM_SALT_LEN to SCRAM_DEFAULT_SALT_LEN, to make it more clear
  that the salt can be of any length, and the constant only specifies how
  long a salt we use when we generate a new verifier. Also rename
  SCRAM_ITERATIONS_DEFAULT to SCRAM_DEFAULT_ITERATIONS, for consistency.

These things caught my eye while working on other upcoming changes.

src/backend/libpq/auth-scram.c
src/common/scram-common.c
src/include/common/scram-common.h
src/interfaces/libpq/fe-auth-scram.c

index 16bea446e37dae75750b41faa360337fd91234be..5c85af943cdcbd2e62a6d073fdf38b8081a969ff 100644 (file)
@@ -396,7 +396,8 @@ scram_build_verifier(const char *username, const char *password,
 {
        char       *prep_password = NULL;
        pg_saslprep_rc rc;
-       char            saltbuf[SCRAM_SALT_LEN];
+       char            saltbuf[SCRAM_DEFAULT_SALT_LEN];
+       uint8           salted_password[SCRAM_KEY_LEN];
        uint8           keybuf[SCRAM_KEY_LEN];
        char       *encoded_salt;
        char       *encoded_storedkey;
@@ -414,10 +415,10 @@ scram_build_verifier(const char *username, const char *password,
                password = (const char *) prep_password;
 
        if (iterations <= 0)
-               iterations = SCRAM_ITERATIONS_DEFAULT;
+               iterations = SCRAM_DEFAULT_ITERATIONS;
 
        /* Generate salt, and encode it in base64 */
-       if (!pg_backend_random(saltbuf, SCRAM_SALT_LEN))
+       if (!pg_backend_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
        {
                ereport(LOG,
                                (errcode(ERRCODE_INTERNAL_ERROR),
@@ -425,13 +426,14 @@ scram_build_verifier(const char *username, const char *password,
                return NULL;
        }
 
-       encoded_salt = palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1);
-       encoded_len = pg_b64_encode(saltbuf, SCRAM_SALT_LEN, encoded_salt);
+       encoded_salt = palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
+       encoded_len = pg_b64_encode(saltbuf, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
        encoded_salt[encoded_len] = '\0';
 
        /* Calculate StoredKey, and encode it in base64 */
-       scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN,
-                                                       iterations, SCRAM_CLIENT_KEY_NAME, keybuf);
+       scram_SaltedPassword(password, saltbuf, SCRAM_DEFAULT_SALT_LEN,
+                                                iterations, salted_password);
+       scram_ClientKey(salted_password, keybuf);
        scram_H(keybuf, SCRAM_KEY_LEN, keybuf);         /* StoredKey */
 
        encoded_storedkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
@@ -440,8 +442,7 @@ scram_build_verifier(const char *username, const char *password,
        encoded_storedkey[encoded_len] = '\0';
 
        /* And same for ServerKey */
-       scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, iterations,
-                                                       SCRAM_SERVER_KEY_NAME, keybuf);
+       scram_ServerKey(salted_password, keybuf);
 
        encoded_serverkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
        encoded_len = pg_b64_encode((const char *) keybuf, SCRAM_KEY_LEN,
@@ -473,6 +474,7 @@ scram_verify_plain_password(const char *username, const char *password,
        char       *salt;
        int                     saltlen;
        int                     iterations;
+       uint8           salted_password[SCRAM_KEY_LEN];
        uint8           stored_key[SCRAM_KEY_LEN];
        uint8           server_key[SCRAM_KEY_LEN];
        uint8           computed_key[SCRAM_KEY_LEN];
@@ -502,9 +504,9 @@ scram_verify_plain_password(const char *username, const char *password,
        if (rc == SASLPREP_SUCCESS)
                password = prep_password;
 
-       /* Compute Server key based on the user-supplied plaintext password */
-       scram_ClientOrServerKey(password, salt, saltlen, iterations,
-                                                       SCRAM_SERVER_KEY_NAME, computed_key);
+       /* Compute Server Key based on the user-supplied plaintext password */
+       scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
+       scram_ServerKey(salted_password, computed_key);
 
        if (prep_password)
                pfree(prep_password);
@@ -630,12 +632,12 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
        /* Generate deterministic salt */
        raw_salt = scram_MockSalt(username);
 
-       encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1);
-       encoded_len = pg_b64_encode(raw_salt, SCRAM_SALT_LEN, encoded_salt);
+       encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
+       encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
        encoded_salt[encoded_len] = '\0';
 
        *salt = encoded_salt;
-       *iterations = SCRAM_ITERATIONS_DEFAULT;
+       *iterations = SCRAM_DEFAULT_ITERATIONS;
 
        /* StoredKey and ServerKey are not used in a doomed authentication */
        memset(stored_key, 0, SCRAM_KEY_LEN);
@@ -1179,7 +1181,7 @@ build_server_final_message(scram_state *state)
 /*
  * Determinisitcally generate salt for mock authentication, using a SHA256
  * hash based on the username and a cluster-level secret key.  Returns a
- * pointer to a static buffer of size SCRAM_SALT_LEN.
+ * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN.
  */
 static char *
 scram_MockSalt(const char *username)
@@ -1194,7 +1196,7 @@ scram_MockSalt(const char *username)
         * not larger the SHA256 digest length. If the salt is smaller, the caller
         * will just ignore the extra data))
         */
-       StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_SALT_LEN,
+       StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
                                         "salt length greater than SHA256 digest length");
 
        pg_sha256_init(&ctx);
index df9f0eaa90d1c6106613fbec6e995d5b963b6384..a8ea44944c493749ce88b82b73fa3322efdc3b21 100644 (file)
@@ -98,14 +98,16 @@ scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx)
 }
 
 /*
- * Iterate hash calculation of HMAC entry using given salt.
- * scram_Hi() is essentially PBKDF2 (see RFC2898) with HMAC() as the
- * pseudorandom function.
+ * Calculate SaltedPassword.
+ *
+ * The password should already be normalized by SASLprep.
  */
-static void
-scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *result)
+void
+scram_SaltedPassword(const char *password,
+                                        const char *salt, int saltlen, int iterations,
+                                        uint8 *result)
 {
-       int                     str_len = strlen(str);
+       int                     password_len = strlen(password);
        uint32          one = htonl(1);
        int                     i,
                                j;
@@ -113,8 +115,14 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
        uint8           Ui_prev[SCRAM_KEY_LEN];
        scram_HMAC_ctx hmac_ctx;
 
+       /*
+        * Iterate hash calculation of HMAC entry using given salt.  This is
+        * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
+        * function.
+        */
+
        /* First iteration */
-       scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len);
+       scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
        scram_HMAC_update(&hmac_ctx, salt, saltlen);
        scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32));
        scram_HMAC_final(Ui_prev, &hmac_ctx);
@@ -123,7 +131,7 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
        /* Subsequent iterations */
        for (i = 2; i <= iterations; i++)
        {
-               scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len);
+               scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
                scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN);
                scram_HMAC_final(Ui, &hmac_ctx);
                for (j = 0; j < SCRAM_KEY_LEN; j++)
@@ -148,20 +156,27 @@ scram_H(const uint8 *input, int len, uint8 *result)
 }
 
 /*
- * Calculate ClientKey or ServerKey.
- *
- * The password should already be normalized by SASLprep.
+ * Calculate ClientKey.
  */
 void
-scram_ClientOrServerKey(const char *password,
-                                               const char *salt, int saltlen, int iterations,
-                                               const char *keystr, uint8 *result)
+scram_ClientKey(const uint8 *salted_password, uint8 *result)
+{
+       scram_HMAC_ctx ctx;
+
+       scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
+       scram_HMAC_update(&ctx, "Client Key", strlen("Client Key"));
+       scram_HMAC_final(result, &ctx);
+}
+
+/*
+ * Calculate ServerKey.
+ */
+void
+scram_ServerKey(const uint8 *salted_password, uint8 *result)
 {
-       uint8           keybuf[SCRAM_KEY_LEN];
        scram_HMAC_ctx ctx;
 
-       scram_Hi(password, salt, saltlen, iterations, keybuf);
-       scram_HMAC_init(&ctx, keybuf, SCRAM_KEY_LEN);
-       scram_HMAC_update(&ctx, keystr, strlen(keystr));
+       scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
+       scram_HMAC_update(&ctx, "Server Key", strlen("Server Key"));
        scram_HMAC_final(result, &ctx);
 }
index 6740069eee18368715748600c18a07dcddc19060..656d9e1e6b1378e70d6bfb5f1c2d29c84282fa02 100644 (file)
 #define SCRAM_RAW_NONCE_LEN                    10
 
 /* length of salt when generating new verifiers */
-#define SCRAM_SALT_LEN                         10
+#define SCRAM_DEFAULT_SALT_LEN         10
 
 /* default number of iterations when generating verifier */
-#define SCRAM_ITERATIONS_DEFAULT       4096
-
-/* Base name of keys used for proof generation */
-#define SCRAM_SERVER_KEY_NAME "Server Key"
-#define SCRAM_CLIENT_KEY_NAME "Client Key"
+#define SCRAM_DEFAULT_ITERATIONS       4096
 
 /*
  * Context data for HMAC used in SCRAM authentication.
@@ -51,9 +47,10 @@ extern void scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen);
 extern void scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen);
 extern void scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx);
 
+extern void scram_SaltedPassword(const char *password, const char *salt,
+                                               int saltlen, int iterations, uint8 *result);
 extern void scram_H(const uint8 *str, int len, uint8 *result);
-extern void scram_ClientOrServerKey(const char *password, const char *salt,
-                                               int saltlen, int iterations,
-                                               const char *keystr, uint8 *result);
+extern void scram_ClientKey(const uint8 *salted_password, uint8 *result);
+extern void scram_ServerKey(const uint8 *salted_password, uint8 *result);
 
 #endif   /* SCRAM_COMMON_H */
index c56e91e0e04bdccd314496144f052f69833cf4a4..be271ce8ac01d3544c8648effd97f141f8f889f1 100644 (file)
@@ -46,6 +46,7 @@ typedef struct
        char       *password;
 
        /* We construct these */
+       uint8           SaltedPassword[SCRAM_KEY_LEN];
        char       *client_nonce;
        char       *client_first_message_bare;
        char       *client_final_message_without_proof;
@@ -59,7 +60,7 @@ typedef struct
 
        /* These come from the server-final message */
        char       *server_final_message;
-       char            ServerProof[SCRAM_KEY_LEN];
+       char            ServerSignature[SCRAM_KEY_LEN];
 } fe_scram_state;
 
 static bool read_server_first_message(fe_scram_state *state, char *input,
@@ -70,7 +71,7 @@ static char *build_client_first_message(fe_scram_state *state,
                                                   PQExpBuffer errormessage);
 static char *build_client_final_message(fe_scram_state *state,
                                                   PQExpBuffer errormessage);
-static bool verify_server_proof(fe_scram_state *state);
+static bool verify_server_signature(fe_scram_state *state);
 static void calculate_client_proof(fe_scram_state *state,
                                           const char *client_final_message_without_proof,
                                           uint8 *result);
@@ -216,12 +217,12 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                goto error;
 
                        /*
-                        * Verify server proof, to make sure we're talking to the genuine
-                        * server.  XXX: A fake server could simply not require
+                        * Verify server signature, to make sure we're talking to the
+                        * genuine server.  XXX: A fake server could simply not require
                         * authentication, though.  There is currently no option in libpq
                         * to reject a connection, if SCRAM authentication did not happen.
                         */
-                       if (verify_server_proof(state))
+                       if (verify_server_signature(state))
                                *success = true;
                        else
                        {
@@ -486,12 +487,11 @@ read_server_first_message(fe_scram_state *state, char *input,
  * Read the final exchange message coming from the server.
  */
 static bool
-read_server_final_message(fe_scram_state *state,
-                                                 char *input,
+read_server_final_message(fe_scram_state *state, char *input,
                                                  PQExpBuffer errormessage)
 {
-       char       *encoded_server_proof;
-       int                     server_proof_len;
+       char       *encoded_server_signature;
+       int                     server_signature_len;
 
        state->server_final_message = strdup(input);
        if (!state->server_final_message)
@@ -513,8 +513,8 @@ read_server_final_message(fe_scram_state *state,
        }
 
        /* Parse the message. */
-       encoded_server_proof = read_attr_value(&input, 'v', errormessage);
-       if (encoded_server_proof == NULL)
+       encoded_server_signature = read_attr_value(&input, 'v', errormessage);
+       if (encoded_server_signature == NULL)
        {
                /* read_attr_value() has generated an error message */
                return false;
@@ -524,13 +524,13 @@ read_server_final_message(fe_scram_state *state,
                printfPQExpBuffer(errormessage,
                                                  libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
 
-       server_proof_len = pg_b64_decode(encoded_server_proof,
-                                                                        strlen(encoded_server_proof),
-                                                                        state->ServerProof);
-       if (server_proof_len != SCRAM_KEY_LEN)
+       server_signature_len = pg_b64_decode(encoded_server_signature,
+                                                                                strlen(encoded_server_signature),
+                                                                                state->ServerSignature);
+       if (server_signature_len != SCRAM_KEY_LEN)
        {
                printfPQExpBuffer(errormessage,
-                 libpq_gettext("malformed SCRAM message (invalid server proof)\n"));
+                                                 libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
                return false;
        }
 
@@ -552,8 +552,14 @@ calculate_client_proof(fe_scram_state *state,
        int                     i;
        scram_HMAC_ctx ctx;
 
-       scram_ClientOrServerKey(state->password, state->salt, state->saltlen,
-                                               state->iterations, SCRAM_CLIENT_KEY_NAME, ClientKey);
+       /*
+        * Calculate SaltedPassword, and store it in 'state' so that we can reuse
+        * it later in verify_server_signature.
+        */
+       scram_SaltedPassword(state->password, state->salt, state->saltlen,
+                                                state->iterations, state->SaltedPassword);
+
+       scram_ClientKey(state->SaltedPassword, ClientKey);
        scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey);
 
        scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN);
@@ -575,19 +581,17 @@ calculate_client_proof(fe_scram_state *state,
 }
 
 /*
- * Validate the server proof, received as part of the final exchange message
- * received from the server.
+ * Validate the server signature, received as part of the final exchange
+ * message received from the server.
  */
 static bool
-verify_server_proof(fe_scram_state *state)
+verify_server_signature(fe_scram_state *state)
 {
-       uint8           ServerSignature[SCRAM_KEY_LEN];
+       uint8           expected_ServerSignature[SCRAM_KEY_LEN];
        uint8           ServerKey[SCRAM_KEY_LEN];
        scram_HMAC_ctx ctx;
 
-       scram_ClientOrServerKey(state->password, state->salt, state->saltlen,
-                                                       state->iterations, SCRAM_SERVER_KEY_NAME,
-                                                       ServerKey);
+       scram_ServerKey(state->SaltedPassword, ServerKey);
 
        /* calculate ServerSignature */
        scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN);
@@ -602,9 +606,9 @@ verify_server_proof(fe_scram_state *state)
        scram_HMAC_update(&ctx,
                                          state->client_final_message_without_proof,
                                          strlen(state->client_final_message_without_proof));
-       scram_HMAC_final(ServerSignature, &ctx);
+       scram_HMAC_final(expected_ServerSignature, &ctx);
 
-       if (memcmp(ServerSignature, state->ServerProof, SCRAM_KEY_LEN) != 0)
+       if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
                return false;
 
        return true;