]> granicus.if.org Git - postgresql/blob - src/backend/libpq/auth-scram.c
Fix use of term "verifier"
[postgresql] / src / backend / libpq / auth-scram.c
1 /*-------------------------------------------------------------------------
2  *
3  * auth-scram.c
4  *        Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
5  *
6  * See the following RFCs for more details:
7  * - RFC 5802: https://tools.ietf.org/html/rfc5802
8  * - RFC 5803: https://tools.ietf.org/html/rfc5803
9  * - RFC 7677: https://tools.ietf.org/html/rfc7677
10  *
 * Here are some differences from those specifications:
12  *
13  * - Username from the authentication exchange is not used. The client
14  *       should send an empty string as the username.
15  *
16  * - If the password isn't valid UTF-8, or contains characters prohibited
17  *       by the SASLprep profile, we skip the SASLprep pre-processing and use
18  *       the raw bytes in calculating the hash.
19  *
20  * - If channel binding is used, the channel binding type is always
21  *       "tls-server-end-point".  The spec says the default is "tls-unique"
22  *       (RFC 5802, section 6.1. Default Channel Binding), but there are some
23  *       problems with that.  Firstly, not all SSL libraries provide an API to
24  *       get the TLS Finished message, required to use "tls-unique".  Secondly,
25  *       "tls-unique" is not specified for TLS v1.3, and as of this writing,
26  *       it's not clear if there will be a replacement.  We could support both
27  *       "tls-server-end-point" and "tls-unique", but for our use case,
28  *       "tls-unique" doesn't really have any advantages.  The main advantage
29  *       of "tls-unique" would be that it works even if the server doesn't
30  *       have a certificate, but PostgreSQL requires a server certificate
31  *       whenever SSL is used, anyway.
32  *
33  *
34  * The password stored in pg_authid consists of the iteration count, salt,
35  * StoredKey and ServerKey.
36  *
37  * SASLprep usage
38  * --------------
39  *
 * One notable difference from the SCRAM specification is that while the
41  * specification dictates that the password is in UTF-8, and prohibits
42  * certain characters, we are more lenient.  If the password isn't a valid
43  * UTF-8 string, or contains prohibited characters, the raw bytes are used
44  * to calculate the hash instead, without SASLprep processing.  This is
45  * because PostgreSQL supports other encodings too, and the encoding being
46  * used during authentication is undefined (client_encoding isn't set until
47  * after authentication).  In effect, we try to interpret the password as
48  * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
49  * that it's in some other encoding.
50  *
51  * In the worst case, we misinterpret a password that's in a different
 * encoding as being Unicode, because it happens to consist entirely of
53  * valid UTF-8 bytes, and we apply Unicode normalization to it.  As long
54  * as we do that consistently, that will not lead to failed logins.
55  * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
56  * don't correspond to any commonly used characters in any of the other
57  * supported encodings, so it should not lead to any significant loss in
58  * entropy, even if the normalization is incorrectly applied to a
59  * non-UTF-8 password.
60  *
61  * Error handling
62  * --------------
63  *
64  * Don't reveal user information to an unauthenticated client.  We don't
65  * want an attacker to be able to probe whether a particular username is
66  * valid.  In SCRAM, the server has to read the salt and iteration count
67  * from the user's stored secret, and send it to the client.  To avoid
68  * revealing whether a user exists, when the client tries to authenticate
69  * with a username that doesn't exist, or doesn't have a valid SCRAM
70  * secret in pg_authid, we create a fake salt and iteration count
71  * on-the-fly, and proceed with the authentication with that.  In the end,
72  * we'll reject the attempt, as if an incorrect password was given.  When
73  * we are performing a "mock" authentication, the 'doomed' flag in
74  * scram_state is set.
75  *
76  * In the error messages, avoid printing strings from the client, unless
77  * you check that they are pure ASCII.  We don't want an unauthenticated
78  * attacker to be able to spam the logs with characters that are not valid
79  * to the encoding being used, whatever that is.  We cannot avoid that in
80  * general, after logging in, but let's do what we can here.
81  *
82  *
83  * Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
84  * Portions Copyright (c) 1994, Regents of the University of California
85  *
86  * src/backend/libpq/auth-scram.c
87  *
88  *-------------------------------------------------------------------------
89  */
#include "postgres.h"

#include <limits.h>
#include <unistd.h>

#include "access/xlog.h"
#include "catalog/pg_authid.h"
#include "catalog/pg_control.h"
#include "common/base64.h"
#include "common/saslprep.h"
#include "common/scram-common.h"
#include "common/sha2.h"
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/scram.h"
#include "miscadmin.h"
#include "utils/builtins.h"
#include "utils/timestamp.h"
107
/*
 * Status data for a SCRAM authentication exchange.  This should be kept
 * internal to this file.
 *
 * pg_be_scram_exchange() moves the exchange from SCRAM_AUTH_INIT to
 * SCRAM_AUTH_SALT_SENT after answering the client's first message, and to
 * SCRAM_AUTH_FINISHED only when the client's proof has been accepted.
 */
typedef enum
{
	SCRAM_AUTH_INIT,			/* waiting for the client's first message */
	SCRAM_AUTH_SALT_SENT,		/* server-first-message sent; waiting for the
								 * client's final message */
	SCRAM_AUTH_FINISHED			/* authentication succeeded */
} scram_state_enum;
118
/*
 * Per-exchange state, allocated (zeroed) by pg_be_scram_init() and passed
 * back to pg_be_scram_exchange() as an opaque pointer.
 */
typedef struct
{
	scram_state_enum state;		/* current step of the exchange */

	/*
	 * NOTE(review): pg_be_scram_init() does not set this (palloc0 leaves it
	 * NULL) and reads port->user_name instead; check whether anything still
	 * reads this field.
	 */
	const char *username;		/* username from startup packet */

	Port	   *port;			/* connection being authenticated */
	bool		channel_binding_in_use; /* client selected the PLUS mechanism */

	/*
	 * Authentication parameters: parsed from the role's stored SCRAM secret,
	 * or produced by mock_scram_secret() when no usable secret exists.
	 */
	int			iterations;
	char	   *salt;			/* base64-encoded */
	uint8		StoredKey[SCRAM_KEY_LEN];
	uint8		ServerKey[SCRAM_KEY_LEN];

	/* Fields of the first message from client */
	/* NOTE(review): presumably the RFC 5802 gs2-cbind-flag; confirm in the parser */
	char		cbind_flag;
	char	   *client_first_message_bare;
	char	   *client_username;
	char	   *client_nonce;

	/* Fields from the last message from client */
	char	   *client_final_message_without_proof;
	char	   *client_final_nonce;
	char		ClientProof[SCRAM_KEY_LEN];

	/* Fields generated in the server */
	char	   *server_first_message;
	char	   *server_nonce;

	/*
	 * If something goes wrong during the authentication, or we are performing
	 * a "mock" authentication (see comments at top of file), the 'doomed'
	 * flag is set.  A reason for the failure, for the server log, is put in
	 * 'logdetail'.
	 */
	bool		doomed;
	char	   *logdetail;
} scram_state;
157
/*
 * Forward declarations of file-local helpers.  Functions that raise
 * protocol errors use sanitize_char()/sanitize_str() so that client-supplied
 * bytes are never logged verbatim.
 */
static void read_client_first_message(scram_state *state, const char *input);
static void read_client_final_message(scram_state *state, const char *input);
static char *build_server_first_message(scram_state *state);
static char *build_server_final_message(scram_state *state);
static bool verify_client_proof(scram_state *state);
static bool verify_final_nonce(scram_state *state);
static void mock_scram_secret(const char *username, int *iterations,
							  char **salt, uint8 *stored_key, uint8 *server_key);
static bool is_scram_printable(char *p);
static char *sanitize_char(char c);
static char *sanitize_str(const char *s);
static char *scram_mock_salt(const char *username);
170
171 /*
172  * pg_be_scram_get_mechanisms
173  *
174  * Get a list of SASL mechanisms that this module supports.
175  *
176  * For the convenience of building the FE/BE packet that lists the
177  * mechanisms, the names are appended to the given StringInfo buffer,
178  * separated by '\0' bytes.
179  */
180 void
181 pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
182 {
183         /*
184          * Advertise the mechanisms in decreasing order of importance.  So the
185          * channel-binding variants go first, if they are supported.  Channel
186          * binding is only supported with SSL, and only if the SSL implementation
187          * has a function to get the certificate's hash.
188          */
189 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
190         if (port->ssl_in_use)
191         {
192                 appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
193                 appendStringInfoChar(buf, '\0');
194         }
195 #endif
196         appendStringInfoString(buf, SCRAM_SHA_256_NAME);
197         appendStringInfoChar(buf, '\0');
198 }
199
200 /*
201  * pg_be_scram_init
202  *
203  * Initialize a new SCRAM authentication exchange status tracker.  This
204  * needs to be called before doing any exchange.  It will be filled later
205  * after the beginning of the exchange with authentication information.
206  *
207  * 'selected_mech' identifies the SASL mechanism that the client selected.
208  * It should be one of the mechanisms that we support, as returned by
209  * pg_be_scram_get_mechanisms().
210  *
211  * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
212  * The username was provided by the client in the startup message, and is
213  * available in port->user_name.  If 'shadow_pass' is NULL, we still perform
214  * an authentication exchange, but it will fail, as if an incorrect password
215  * was given.
216  */
217 void *
218 pg_be_scram_init(Port *port,
219                                  const char *selected_mech,
220                                  const char *shadow_pass)
221 {
222         scram_state *state;
223         bool            got_secret;
224
225         state = (scram_state *) palloc0(sizeof(scram_state));
226         state->port = port;
227         state->state = SCRAM_AUTH_INIT;
228
229         /*
230          * Parse the selected mechanism.
231          *
232          * Note that if we don't support channel binding, either because the SSL
233          * implementation doesn't support it or we're not using SSL at all, we
234          * would not have advertised the PLUS variant in the first place.  If the
235          * client nevertheless tries to select it, it's a protocol violation like
236          * selecting any other SASL mechanism we don't support.
237          */
238 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
239         if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
240                 state->channel_binding_in_use = true;
241         else
242 #endif
243         if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
244                 state->channel_binding_in_use = false;
245         else
246                 ereport(ERROR,
247                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
248                                  errmsg("client selected an invalid SASL authentication mechanism")));
249
250         /*
251          * Parse the stored secret.
252          */
253         if (shadow_pass)
254         {
255                 int                     password_type = get_password_type(shadow_pass);
256
257                 if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
258                 {
259                         if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
260                                                                          state->StoredKey, state->ServerKey))
261                                 got_secret = true;
262                         else
263                         {
264                                 /*
265                                  * The password looked like a SCRAM secret, but could not be
266                                  * parsed.
267                                  */
268                                 ereport(LOG,
269                                                 (errmsg("invalid SCRAM secret for user \"%s\"",
270                                                                 state->port->user_name)));
271                                 got_secret = false;
272                         }
273                 }
274                 else
275                 {
276                         /*
277                          * The user doesn't have SCRAM secret. (You cannot do SCRAM
278                          * authentication with an MD5 hash.)
279                          */
280                         state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."),
281                                                                                 state->port->user_name);
282                         got_secret = false;
283                 }
284         }
285         else
286         {
287                 /*
288                  * The caller requested us to perform a dummy authentication.  This is
289                  * considered normal, since the caller requested it, so don't set log
290                  * detail.
291                  */
292                 got_secret = false;
293         }
294
295         /*
296          * If the user did not have a valid SCRAM secret, we still go through
297          * the motions with a mock one, and fail as if the client supplied an
298          * incorrect password.  This is to avoid revealing information to an
299          * attacker.
300          */
301         if (!got_secret)
302         {
303                 mock_scram_secret(state->port->user_name, &state->iterations,
304                                                         &state->salt, state->StoredKey, state->ServerKey);
305                 state->doomed = true;
306         }
307
308         return state;
309 }
310
311 /*
312  * Continue a SCRAM authentication exchange.
313  *
314  * 'input' is the SCRAM payload sent by the client.  On the first call,
315  * 'input' contains the "Initial Client Response" that the client sent as
316  * part of the SASLInitialResponse message, or NULL if no Initial Client
317  * Response was given.  (The SASL specification distinguishes between an
318  * empty response and non-existing one.)  On subsequent calls, 'input'
319  * cannot be NULL.  For convenience in this function, the caller must
320  * ensure that there is a null terminator at input[inputlen].
321  *
322  * The next message to send to client is saved in 'output', for a length
323  * of 'outputlen'.  In the case of an error, optionally store a palloc'd
324  * string at *logdetail that will be sent to the postmaster log (but not
325  * the client).
326  */
327 int
328 pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
329                                          char **output, int *outputlen, char **logdetail)
330 {
331         scram_state *state = (scram_state *) opaq;
332         int                     result;
333
334         *output = NULL;
335
336         /*
337          * If the client didn't include an "Initial Client Response" in the
338          * SASLInitialResponse message, send an empty challenge, to which the
339          * client will respond with the same data that usually comes in the
340          * Initial Client Response.
341          */
342         if (input == NULL)
343         {
344                 Assert(state->state == SCRAM_AUTH_INIT);
345
346                 *output = pstrdup("");
347                 *outputlen = 0;
348                 return SASL_EXCHANGE_CONTINUE;
349         }
350
351         /*
352          * Check that the input length agrees with the string length of the input.
353          * We can ignore inputlen after this.
354          */
355         if (inputlen == 0)
356                 ereport(ERROR,
357                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
358                                  errmsg("malformed SCRAM message"),
359                                  errdetail("The message is empty.")));
360         if (inputlen != strlen(input))
361                 ereport(ERROR,
362                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
363                                  errmsg("malformed SCRAM message"),
364                                  errdetail("Message length does not match input length.")));
365
366         switch (state->state)
367         {
368                 case SCRAM_AUTH_INIT:
369
370                         /*
371                          * Initialization phase.  Receive the first message from client
372                          * and be sure that it parsed correctly.  Then send the challenge
373                          * to the client.
374                          */
375                         read_client_first_message(state, input);
376
377                         /* prepare message to send challenge */
378                         *output = build_server_first_message(state);
379
380                         state->state = SCRAM_AUTH_SALT_SENT;
381                         result = SASL_EXCHANGE_CONTINUE;
382                         break;
383
384                 case SCRAM_AUTH_SALT_SENT:
385
386                         /*
387                          * Final phase for the server.  Receive the response to the
388                          * challenge previously sent, verify, and let the client know that
389                          * everything went well (or not).
390                          */
391                         read_client_final_message(state, input);
392
393                         if (!verify_final_nonce(state))
394                                 ereport(ERROR,
395                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
396                                                  errmsg("invalid SCRAM response"),
397                                                  errdetail("Nonce does not match.")));
398
399                         /*
400                          * Now check the final nonce and the client proof.
401                          *
402                          * If we performed a "mock" authentication that we knew would fail
403                          * from the get go, this is where we fail.
404                          *
405                          * The SCRAM specification includes an error code,
406                          * "invalid-proof", for authentication failure, but it also allows
407                          * erroring out in an application-specific way.  We choose to do
408                          * the latter, so that the error message for invalid password is
409                          * the same for all authentication methods.  The caller will call
410                          * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
411                          *
412                          * NB: the order of these checks is intentional.  We calculate the
413                          * client proof even in a mock authentication, even though it's
414                          * bound to fail, to thwart timing attacks to determine if a role
415                          * with the given name exists or not.
416                          */
417                         if (!verify_client_proof(state) || state->doomed)
418                         {
419                                 result = SASL_EXCHANGE_FAILURE;
420                                 break;
421                         }
422
423                         /* Build final message for client */
424                         *output = build_server_final_message(state);
425
426                         /* Success! */
427                         result = SASL_EXCHANGE_SUCCESS;
428                         state->state = SCRAM_AUTH_FINISHED;
429                         break;
430
431                 default:
432                         elog(ERROR, "invalid SCRAM exchange state");
433                         result = SASL_EXCHANGE_FAILURE;
434         }
435
436         if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
437                 *logdetail = state->logdetail;
438
439         if (*output)
440                 *outputlen = strlen(*output);
441
442         return result;
443 }
444
445 /*
446  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
447  *
448  * The result is palloc'd, so caller is responsible for freeing it.
449  */
450 char *
451 pg_be_scram_build_secret(const char *password)
452 {
453         char       *prep_password;
454         pg_saslprep_rc rc;
455         char            saltbuf[SCRAM_DEFAULT_SALT_LEN];
456         char       *result;
457
458         /*
459          * Normalize the password with SASLprep.  If that doesn't work, because
460          * the password isn't valid UTF-8 or contains prohibited characters, just
461          * proceed with the original password.  (See comments at top of file.)
462          */
463         rc = pg_saslprep(password, &prep_password);
464         if (rc == SASLPREP_SUCCESS)
465                 password = (const char *) prep_password;
466
467         /* Generate random salt */
468         if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
469                 ereport(ERROR,
470                                 (errcode(ERRCODE_INTERNAL_ERROR),
471                                  errmsg("could not generate random salt")));
472
473         result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
474                                                                   SCRAM_DEFAULT_ITERATIONS, password);
475
476         if (prep_password)
477                 pfree(prep_password);
478
479         return result;
480 }
481
482 /*
483  * Verify a plaintext password against a SCRAM secret.  This is used when
484  * performing plaintext password authentication for a user that has a SCRAM
485  * secret stored in pg_authid.
486  */
487 bool
488 scram_verify_plain_password(const char *username, const char *password,
489                                                         const char *secret)
490 {
491         char       *encoded_salt;
492         char       *salt;
493         int                     saltlen;
494         int                     iterations;
495         uint8           salted_password[SCRAM_KEY_LEN];
496         uint8           stored_key[SCRAM_KEY_LEN];
497         uint8           server_key[SCRAM_KEY_LEN];
498         uint8           computed_key[SCRAM_KEY_LEN];
499         char       *prep_password;
500         pg_saslprep_rc rc;
501
502         if (!parse_scram_secret(secret, &iterations, &encoded_salt,
503                                                           stored_key, server_key))
504         {
505                 /*
506                  * The password looked like a SCRAM secret, but could not be parsed.
507                  */
508                 ereport(LOG,
509                                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
510                 return false;
511         }
512
513         saltlen = pg_b64_dec_len(strlen(encoded_salt));
514         salt = palloc(saltlen);
515         saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
516                                                         saltlen);
517         if (saltlen < 0)
518         {
519                 ereport(LOG,
520                                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
521                 return false;
522         }
523
524         /* Normalize the password */
525         rc = pg_saslprep(password, &prep_password);
526         if (rc == SASLPREP_SUCCESS)
527                 password = prep_password;
528
529         /* Compute Server Key based on the user-supplied plaintext password */
530         scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
531         scram_ServerKey(salted_password, computed_key);
532
533         if (prep_password)
534                 pfree(prep_password);
535
536         /*
537          * Compare the secret's Server Key with the one computed from the
538          * user-supplied password.
539          */
540         return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
541 }
542
543
544 /*
545  * Parse and validate format of given SCRAM secret.
546  *
547  * On success, the iteration count, salt, stored key, and server key are
548  * extracted from the secret, and returned to the caller.  For 'stored_key'
549  * and 'server_key', the caller must pass pre-allocated buffers of size
550  * SCRAM_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
551  * string.  The buffer for the salt is palloc'd by this function.
552  *
553  * Returns true if the SCRAM secret has been parsed, and false otherwise.
554  */
555 bool
556 parse_scram_secret(const char *secret, int *iterations, char **salt,
557                                          uint8 *stored_key, uint8 *server_key)
558 {
559         char       *v;
560         char       *p;
561         char       *scheme_str;
562         char       *salt_str;
563         char       *iterations_str;
564         char       *storedkey_str;
565         char       *serverkey_str;
566         int                     decoded_len;
567         char       *decoded_salt_buf;
568         char       *decoded_stored_buf;
569         char       *decoded_server_buf;
570
571         /*
572          * The secret is of form:
573          *
574          * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
575          */
576         v = pstrdup(secret);
577         if ((scheme_str = strtok(v, "$")) == NULL)
578                 goto invalid_secret;
579         if ((iterations_str = strtok(NULL, ":")) == NULL)
580                 goto invalid_secret;
581         if ((salt_str = strtok(NULL, "$")) == NULL)
582                 goto invalid_secret;
583         if ((storedkey_str = strtok(NULL, ":")) == NULL)
584                 goto invalid_secret;
585         if ((serverkey_str = strtok(NULL, "")) == NULL)
586                 goto invalid_secret;
587
588         /* Parse the fields */
589         if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
590                 goto invalid_secret;
591
592         errno = 0;
593         *iterations = strtol(iterations_str, &p, 10);
594         if (*p || errno != 0)
595                 goto invalid_secret;
596
597         /*
598          * Verify that the salt is in Base64-encoded format, by decoding it,
599          * although we return the encoded version to the caller.
600          */
601         decoded_len = pg_b64_dec_len(strlen(salt_str));
602         decoded_salt_buf = palloc(decoded_len);
603         decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
604                                                                 decoded_salt_buf, decoded_len);
605         if (decoded_len < 0)
606                 goto invalid_secret;
607         *salt = pstrdup(salt_str);
608
609         /*
610          * Decode StoredKey and ServerKey.
611          */
612         decoded_len = pg_b64_dec_len(strlen(storedkey_str));
613         decoded_stored_buf = palloc(decoded_len);
614         decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
615                                                                 decoded_stored_buf, decoded_len);
616         if (decoded_len != SCRAM_KEY_LEN)
617                 goto invalid_secret;
618         memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
619
620         decoded_len = pg_b64_dec_len(strlen(serverkey_str));
621         decoded_server_buf = palloc(decoded_len);
622         decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
623                                                                 decoded_server_buf, decoded_len);
624         if (decoded_len != SCRAM_KEY_LEN)
625                 goto invalid_secret;
626         memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
627
628         return true;
629
630 invalid_secret:
631         *salt = NULL;
632         return false;
633 }
634
635 /*
636  * Generate plausible SCRAM secret parameters for mock authentication.
637  *
638  * In a normal authentication, these are extracted from the secret
639  * stored in the server.  This function generates values that look
640  * realistic, for when there is no stored secret.
641  *
642  * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
643  * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
644  * the buffer for the salt is palloc'd by this function.
645  */
646 static void
647 mock_scram_secret(const char *username, int *iterations, char **salt,
648                                         uint8 *stored_key, uint8 *server_key)
649 {
650         char       *raw_salt;
651         char       *encoded_salt;
652         int                     encoded_len;
653
654         /* Generate deterministic salt */
655         raw_salt = scram_mock_salt(username);
656
657         encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
658         /* don't forget the zero-terminator */
659         encoded_salt = (char *) palloc(encoded_len + 1);
660         encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
661                                                                 encoded_len);
662
663         /*
664          * Note that we cannot reveal any information to an attacker here so the
665          * error message needs to remain generic.  This should never fail anyway
666          * as the salt generated for mock authentication uses the cluster's nonce
667          * value.
668          */
669         if (encoded_len < 0)
670                 elog(ERROR, "could not encode salt");
671         encoded_salt[encoded_len] = '\0';
672
673         *salt = encoded_salt;
674         *iterations = SCRAM_DEFAULT_ITERATIONS;
675
676         /* StoredKey and ServerKey are not used in a doomed authentication */
677         memset(stored_key, 0, SCRAM_KEY_LEN);
678         memset(server_key, 0, SCRAM_KEY_LEN);
679 }
680
681 /*
682  * Read the value in a given SCRAM exchange message for given attribute.
683  */
684 static char *
685 read_attr_value(char **input, char attr)
686 {
687         char       *begin = *input;
688         char       *end;
689
690         if (*begin != attr)
691                 ereport(ERROR,
692                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
693                                  errmsg("malformed SCRAM message"),
694                                  errdetail("Expected attribute \"%c\" but found \"%s\".",
695                                                    attr, sanitize_char(*begin))));
696         begin++;
697
698         if (*begin != '=')
699                 ereport(ERROR,
700                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
701                                  errmsg("malformed SCRAM message"),
702                                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
703         begin++;
704
705         end = begin;
706         while (*end && *end != ',')
707                 end++;
708
709         if (*end)
710         {
711                 *end = '\0';
712                 *input = end + 1;
713         }
714         else
715                 *input = end;
716
717         return begin;
718 }
719
/*
 * Report whether every byte of the NUL-terminated string 'p' is a SCRAM
 * "printable" character.  An empty string counts as printable.
 *
 *------
 * Printable characters, as defined by SCRAM spec: (RFC 5802)
 *
 *	printable		= %x21-2B / %x2D-7E
 *					  ;; Printable ASCII except ",".
 *					  ;; Note that any "printable" is also
 *					  ;; a valid "value".
 *------
 */
static bool
is_scram_printable(char *p)
{
	const unsigned char *cur = (const unsigned char *) p;

	while (*cur != '\0')
	{
		unsigned char ch = *cur++;

		/* Bytes >= 0x80 fall above 0x7E and are rejected here. */
		if (ch == ',' || ch < 0x21 || ch > 0x7E)
			return false;
	}
	return true;
}
739
/*
 * Render a single byte in a form that is safe to embed in an error message.
 *
 * A printable ASCII character (0x21-0x7E) comes back wrapped in single
 * quotes, e.g. 'a'.  Anything else, including space and NUL, comes back as
 * a two-digit hex literal such as 0x20.
 *
 * The result lives in a static buffer that the next call overwrites.
 */
static char *
sanitize_char(char c)
{
	static char buf[5];			/* longest output is "0xNN" plus NUL */
	unsigned char byte = (unsigned char) c;

	if (byte < 0x21 || byte > 0x7E)
		snprintf(buf, sizeof(buf), "0x%02x", byte);
	else
		snprintf(buf, sizeof(buf), "'%c'", c);
	return buf;
}
759
/*
 * Render an arbitrary, possibly hostile string in a form that is safe to
 * embed in an error message.
 *
 * At most the first 30 bytes of 's' are copied, stopping early at its
 * terminating NUL.  Every byte that is not printable ASCII (0x21-0x7E) is
 * replaced by '?'.
 *
 * The result lives in a static buffer that the next call overwrites.
 */
static char *
sanitize_str(const char *s)
{
	static char buf[30 + 1];
	char	   *out = buf;
	char	   *limit = buf + (sizeof(buf) - 1);	/* reserve room for NUL */

	while (out < limit && *s != '\0')
	{
		char		c = *s++;

		*out++ = (c >= 0x21 && c <= 0x7E) ? c : '?';
	}
	*out = '\0';
	return buf;
}
789
790 /*
791  * Read the next attribute and value in a SCRAM exchange message.
792  *
793  * The attribute character is set in *attr_p, the attribute value is the
794  * return value.
795  */
796 static char *
797 read_any_attr(char **input, char *attr_p)
798 {
799         char       *begin = *input;
800         char       *end;
801         char            attr = *begin;
802
803         if (attr == '\0')
804                 ereport(ERROR,
805                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
806                                  errmsg("malformed SCRAM message"),
807                                  errdetail("Attribute expected, but found end of string.")));
808
809         /*------
810          * attr-val                = ALPHA "=" value
811          *                                       ;; Generic syntax of any attribute sent
812          *                                       ;; by server or client
813          *------
814          */
815         if (!((attr >= 'A' && attr <= 'Z') ||
816                   (attr >= 'a' && attr <= 'z')))
817                 ereport(ERROR,
818                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
819                                  errmsg("malformed SCRAM message"),
820                                  errdetail("Attribute expected, but found invalid character \"%s\".",
821                                                    sanitize_char(attr))));
822         if (attr_p)
823                 *attr_p = attr;
824         begin++;
825
826         if (*begin != '=')
827                 ereport(ERROR,
828                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
829                                  errmsg("malformed SCRAM message"),
830                                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
831         begin++;
832
833         end = begin;
834         while (*end && *end != ',')
835                 end++;
836
837         if (*end)
838         {
839                 *end = '\0';
840                 *input = end + 1;
841         }
842         else
843                 *input = end;
844
845         return begin;
846 }
847
848 /*
849  * Read and parse the first message from client in the context of a SCRAM
850  * authentication exchange message.
851  *
852  * At this stage, any errors will be reported directly with ereport(ERROR).
853  */
854 static void
855 read_client_first_message(scram_state *state, const char *input)
856 {
857         char       *p = pstrdup(input);
858         char       *channel_binding_type;
859
860
861         /*------
862          * The syntax for the client-first-message is: (RFC 5802)
863          *
864          * saslname                = 1*(value-safe-char / "=2C" / "=3D")
865          *                                       ;; Conforms to <value>.
866          *
867          * authzid                 = "a=" saslname
868          *                                       ;; Protocol specific.
869          *
870          * cb-name                 = 1*(ALPHA / DIGIT / "." / "-")
871          *                                        ;; See RFC 5056, Section 7.
872          *                                        ;; E.g., "tls-server-end-point" or
873          *                                        ;; "tls-unique".
874          *
875          * gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
876          *                                       ;; "n" -> client doesn't support channel binding.
877          *                                       ;; "y" -> client does support channel binding
878          *                                       ;;                but thinks the server does not.
879          *                                       ;; "p" -> client requires channel binding.
880          *                                       ;; The selected channel binding follows "p=".
881          *
882          * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
883          *                                       ;; GS2 header for SCRAM
884          *                                       ;; (the actual GS2 header includes an optional
885          *                                       ;; flag to indicate that the GSS mechanism is not
886          *                                       ;; "standard", but since SCRAM is "standard", we
887          *                                       ;; don't include that flag).
888          *
889          * username                = "n=" saslname
890          *                                       ;; Usernames are prepared using SASLprep.
891          *
892          * reserved-mext  = "m=" 1*(value-char)
893          *                                       ;; Reserved for signaling mandatory extensions.
894          *                                       ;; The exact syntax will be defined in
895          *                                       ;; the future.
896          *
897          * nonce                   = "r=" c-nonce [s-nonce]
898          *                                       ;; Second part provided by server.
899          *
900          * c-nonce                 = printable
901          *
902          * client-first-message-bare =
903          *                                       [reserved-mext ","]
904          *                                       username "," nonce ["," extensions]
905          *
906          * client-first-message =
907          *                                       gs2-header client-first-message-bare
908          *
909          * For example:
910          * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
911          *
912          * The "n,," in the beginning means that the client doesn't support
913          * channel binding, and no authzid is given.  "n=user" is the username.
914          * However, in PostgreSQL the username is sent in the startup packet, and
915          * the username in the SCRAM exchange is ignored.  libpq always sends it
916          * as an empty string.  The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
917          * the client nonce.
918          *------
919          */
920
921         /*
922          * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
923          * Binding".)
924          */
925         state->cbind_flag = *p;
926         switch (*p)
927         {
928                 case 'n':
929
930                         /*
931                          * The client does not support channel binding or has simply
932                          * decided to not use it.  In that case just let it go.
933                          */
934                         if (state->channel_binding_in_use)
935                                 ereport(ERROR,
936                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
937                                                  errmsg("malformed SCRAM message"),
938                                                  errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
939
940                         p++;
941                         if (*p != ',')
942                                 ereport(ERROR,
943                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
944                                                  errmsg("malformed SCRAM message"),
945                                                  errdetail("Comma expected, but found character \"%s\".",
946                                                                    sanitize_char(*p))));
947                         p++;
948                         break;
949                 case 'y':
950
951                         /*
952                          * The client supports channel binding and thinks that the server
953                          * does not.  In this case, the server must fail authentication if
954                          * it supports channel binding.
955                          */
956                         if (state->channel_binding_in_use)
957                                 ereport(ERROR,
958                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
959                                                  errmsg("malformed SCRAM message"),
960                                                  errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
961
962 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
963                         if (state->port->ssl_in_use)
964                                 ereport(ERROR,
965                                                 (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
966                                                  errmsg("SCRAM channel binding negotiation error"),
967                                                  errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
968                                                                    "However, this server does support channel binding.")));
969 #endif
970                         p++;
971                         if (*p != ',')
972                                 ereport(ERROR,
973                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
974                                                  errmsg("malformed SCRAM message"),
975                                                  errdetail("Comma expected, but found character \"%s\".",
976                                                                    sanitize_char(*p))));
977                         p++;
978                         break;
979                 case 'p':
980
981                         /*
982                          * The client requires channel binding.  Channel binding type
983                          * follows, e.g., "p=tls-server-end-point".
984                          */
985                         if (!state->channel_binding_in_use)
986                                 ereport(ERROR,
987                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
988                                                  errmsg("malformed SCRAM message"),
989                                                  errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
990
991                         channel_binding_type = read_attr_value(&p, 'p');
992
993                         /*
994                          * The only channel binding type we support is
995                          * tls-server-end-point.
996                          */
997                         if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
998                                 ereport(ERROR,
999                                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
1000                                                  (errmsg("unsupported SCRAM channel-binding type \"%s\"",
1001                                                                  sanitize_str(channel_binding_type)))));
1002                         break;
1003                 default:
1004                         ereport(ERROR,
1005                                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
1006                                          errmsg("malformed SCRAM message"),
1007                                          errdetail("Unexpected channel-binding flag \"%s\".",
1008                                                            sanitize_char(*p))));
1009         }
1010
1011         /*
1012          * Forbid optional authzid (authorization identity).  We don't support it.
1013          */
1014         if (*p == 'a')
1015                 ereport(ERROR,
1016                                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1017                                  errmsg("client uses authorization identity, but it is not supported")));
1018         if (*p != ',')
1019                 ereport(ERROR,
1020                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
1021                                  errmsg("malformed SCRAM message"),
1022                                  errdetail("Unexpected attribute \"%s\" in client-first-message.",
1023                                                    sanitize_char(*p))));
1024         p++;
1025
1026         state->client_first_message_bare = pstrdup(p);
1027
1028         /*
1029          * Any mandatory extensions would go here.  We don't support any.
1030          *
1031          * RFC 5802 specifies error code "e=extensions-not-supported" for this,
1032          * but it can only be sent in the server-final message.  We prefer to fail
1033          * immediately (which the RFC also allows).
1034          */
1035         if (*p == 'm')
1036                 ereport(ERROR,
1037                                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1038                                  errmsg("client requires an unsupported SCRAM extension")));
1039
1040         /*
1041          * Read username.  Note: this is ignored.  We use the username from the
1042          * startup message instead, still it is kept around if provided as it
1043          * proves to be useful for debugging purposes.
1044          */
1045         state->client_username = read_attr_value(&p, 'n');
1046
1047         /* read nonce and check that it is made of only printable characters */
1048         state->client_nonce = read_attr_value(&p, 'r');
1049         if (!is_scram_printable(state->client_nonce))
1050                 ereport(ERROR,
1051                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
1052                                  errmsg("non-printable characters in SCRAM nonce")));
1053
1054         /*
1055          * There can be any number of optional extensions after this.  We don't
1056          * support any extensions, so ignore them.
1057          */
1058         while (*p != '\0')
1059                 read_any_attr(&p, NULL);
1060
1061         /* success! */
1062 }
1063
1064 /*
1065  * Verify the final nonce contained in the last message received from
1066  * client in an exchange.
1067  */
1068 static bool
1069 verify_final_nonce(scram_state *state)
1070 {
1071         int                     client_nonce_len = strlen(state->client_nonce);
1072         int                     server_nonce_len = strlen(state->server_nonce);
1073         int                     final_nonce_len = strlen(state->client_final_nonce);
1074
1075         if (final_nonce_len != client_nonce_len + server_nonce_len)
1076                 return false;
1077         if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
1078                 return false;
1079         if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
1080                 return false;
1081
1082         return true;
1083 }
1084
1085 /*
1086  * Verify the client proof contained in the last message received from
1087  * client in an exchange.
1088  */
1089 static bool
1090 verify_client_proof(scram_state *state)
1091 {
1092         uint8           ClientSignature[SCRAM_KEY_LEN];
1093         uint8           ClientKey[SCRAM_KEY_LEN];
1094         uint8           client_StoredKey[SCRAM_KEY_LEN];
1095         scram_HMAC_ctx ctx;
1096         int                     i;
1097
1098         /* calculate ClientSignature */
1099         scram_HMAC_init(&ctx, state->StoredKey, SCRAM_KEY_LEN);
1100         scram_HMAC_update(&ctx,
1101                                           state->client_first_message_bare,
1102                                           strlen(state->client_first_message_bare));
1103         scram_HMAC_update(&ctx, ",", 1);
1104         scram_HMAC_update(&ctx,
1105                                           state->server_first_message,
1106                                           strlen(state->server_first_message));
1107         scram_HMAC_update(&ctx, ",", 1);
1108         scram_HMAC_update(&ctx,
1109                                           state->client_final_message_without_proof,
1110                                           strlen(state->client_final_message_without_proof));
1111         scram_HMAC_final(ClientSignature, &ctx);
1112
1113         /* Extract the ClientKey that the client calculated from the proof */
1114         for (i = 0; i < SCRAM_KEY_LEN; i++)
1115                 ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
1116
1117         /* Hash it one more time, and compare with StoredKey */
1118         scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey);
1119
1120         if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
1121                 return false;
1122
1123         return true;
1124 }
1125
1126 /*
1127  * Build the first server-side message sent to the client in a SCRAM
1128  * communication exchange.
1129  */
1130 static char *
1131 build_server_first_message(scram_state *state)
1132 {
1133         /*------
1134          * The syntax for the server-first-message is: (RFC 5802)
1135          *
1136          * server-first-message =
1137          *                                       [reserved-mext ","] nonce "," salt ","
1138          *                                       iteration-count ["," extensions]
1139          *
1140          * nonce                   = "r=" c-nonce [s-nonce]
1141          *                                       ;; Second part provided by server.
1142          *
1143          * c-nonce                 = printable
1144          *
1145          * s-nonce                 = printable
1146          *
1147          * salt                    = "s=" base64
1148          *
1149          * iteration-count = "i=" posit-number
1150          *                                       ;; A positive number.
1151          *
1152          * Example:
1153          *
1154          * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
1155          *------
1156          */
1157
1158         /*
1159          * Per the spec, the nonce may consist of any printable ASCII characters.
1160          * For convenience, however, we don't use the whole range available,
1161          * rather, we generate some random bytes, and base64 encode them.
1162          */
1163         char            raw_nonce[SCRAM_RAW_NONCE_LEN];
1164         int                     encoded_len;
1165
1166         if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
1167                 ereport(ERROR,
1168                                 (errcode(ERRCODE_INTERNAL_ERROR),
1169                                  errmsg("could not generate random nonce")));
1170
1171         encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
1172         /* don't forget the zero-terminator */
1173         state->server_nonce = palloc(encoded_len + 1);
1174         encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
1175                                                                 state->server_nonce, encoded_len);
1176         if (encoded_len < 0)
1177                 ereport(ERROR,
1178                                 (errcode(ERRCODE_INTERNAL_ERROR),
1179                                  errmsg("could not encode random nonce")));
1180         state->server_nonce[encoded_len] = '\0';
1181
1182         state->server_first_message =
1183                 psprintf("r=%s%s,s=%s,i=%u",
1184                                  state->client_nonce, state->server_nonce,
1185                                  state->salt, state->iterations);
1186
1187         return pstrdup(state->server_first_message);
1188 }
1189
1190
1191 /*
1192  * Read and parse the final message received from client.
1193  */
1194 static void
1195 read_client_final_message(scram_state *state, const char *input)
1196 {
1197         char            attr;
1198         char       *channel_binding;
1199         char       *value;
1200         char       *begin,
1201                            *proof;
1202         char       *p;
1203         char       *client_proof;
1204         int                     client_proof_len;
1205
1206         begin = p = pstrdup(input);
1207
1208         /*------
1209          * The syntax for the server-first-message is: (RFC 5802)
1210          *
1211          * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
1212          *                                       ;; GS2 header for SCRAM
1213          *                                       ;; (the actual GS2 header includes an optional
1214          *                                       ;; flag to indicate that the GSS mechanism is not
1215          *                                       ;; "standard", but since SCRAM is "standard", we
1216          *                                       ;; don't include that flag).
1217          *
1218          * cbind-input   = gs2-header [ cbind-data ]
1219          *                                       ;; cbind-data MUST be present for
1220          *                                       ;; gs2-cbind-flag of "p" and MUST be absent
1221          *                                       ;; for "y" or "n".
1222          *
1223          * channel-binding = "c=" base64
1224          *                                       ;; base64 encoding of cbind-input.
1225          *
1226          * proof                   = "p=" base64
1227          *
1228          * client-final-message-without-proof =
1229          *                                       channel-binding "," nonce [","
1230          *                                       extensions]
1231          *
1232          * client-final-message =
1233          *                                       client-final-message-without-proof "," proof
1234          *------
1235          */
1236
1237         /*
1238          * Read channel binding.  This repeats the channel-binding flags and is
1239          * then followed by the actual binding data depending on the type.
1240          */
1241         channel_binding = read_attr_value(&p, 'c');
1242         if (state->channel_binding_in_use)
1243         {
1244 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
1245                 const char *cbind_data = NULL;
1246                 size_t          cbind_data_len = 0;
1247                 size_t          cbind_header_len;
1248                 char       *cbind_input;
1249                 size_t          cbind_input_len;
1250                 char       *b64_message;
1251                 int                     b64_message_len;
1252
1253                 Assert(state->cbind_flag == 'p');
1254
1255                 /* Fetch hash data of server's SSL certificate */
1256                 cbind_data = be_tls_get_certificate_hash(state->port,
1257                                                                                                  &cbind_data_len);
1258
1259                 /* should not happen */
1260                 if (cbind_data == NULL || cbind_data_len == 0)
1261                         elog(ERROR, "could not get server certificate hash");
1262
1263                 cbind_header_len = strlen("p=tls-server-end-point,,");  /* p=type,, */
1264                 cbind_input_len = cbind_header_len + cbind_data_len;
1265                 cbind_input = palloc(cbind_input_len);
1266                 snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
1267                 memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
1268
1269                 b64_message_len = pg_b64_enc_len(cbind_input_len);
1270                 /* don't forget the zero-terminator */
1271                 b64_message = palloc(b64_message_len + 1);
1272                 b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
1273                                                                                 b64_message, b64_message_len);
1274                 if (b64_message_len < 0)
1275                         elog(ERROR, "could not encode channel binding data");
1276                 b64_message[b64_message_len] = '\0';
1277
1278                 /*
1279                  * Compare the value sent by the client with the value expected by the
1280                  * server.
1281                  */
1282                 if (strcmp(channel_binding, b64_message) != 0)
1283                         ereport(ERROR,
1284                                         (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
1285                                          (errmsg("SCRAM channel binding check failed"))));
1286 #else
1287                 /* shouldn't happen, because we checked this earlier already */
1288                 elog(ERROR, "channel binding not supported by this build");
1289 #endif
1290         }
1291         else
1292         {
1293                 /*
1294                  * If we are not using channel binding, the binding data is expected
1295                  * to always be "biws", which is "n,," base64-encoded, or "eSws",
1296                  * which is "y,,".  We also have to check whether the flag is the same
1297                  * one that the client originally sent.
1298                  */
1299                 if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
1300                         !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
1301                         ereport(ERROR,
1302                                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
1303                                          (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));
1304         }
1305
1306         state->client_final_nonce = read_attr_value(&p, 'r');
1307
1308         /* ignore optional extensions, read until we find "p" attribute */
1309         do
1310         {
1311                 proof = p - 1;
1312                 value = read_any_attr(&p, &attr);
1313         } while (attr != 'p');
1314
1315         client_proof_len = pg_b64_dec_len(strlen(value));
1316         client_proof = palloc(client_proof_len);
1317         if (pg_b64_decode(value, strlen(value), client_proof,
1318                                           client_proof_len) != SCRAM_KEY_LEN)
1319                 ereport(ERROR,
1320                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
1321                                  errmsg("malformed SCRAM message"),
1322                                  errdetail("Malformed proof in client-final-message.")));
1323         memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
1324         pfree(client_proof);
1325
1326         if (*p != '\0')
1327                 ereport(ERROR,
1328                                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
1329                                  errmsg("malformed SCRAM message"),
1330                                  errdetail("Garbage found at the end of client-final-message.")));
1331
1332         state->client_final_message_without_proof = palloc(proof - begin + 1);
1333         memcpy(state->client_final_message_without_proof, input, proof - begin);
1334         state->client_final_message_without_proof[proof - begin] = '\0';
1335 }
1336
1337 /*
1338  * Build the final server-side message of an exchange.
1339  */
1340 static char *
1341 build_server_final_message(scram_state *state)
1342 {
1343         uint8           ServerSignature[SCRAM_KEY_LEN];
1344         char       *server_signature_base64;
1345         int                     siglen;
1346         scram_HMAC_ctx ctx;
1347
1348         /* calculate ServerSignature */
1349         scram_HMAC_init(&ctx, state->ServerKey, SCRAM_KEY_LEN);
1350         scram_HMAC_update(&ctx,
1351                                           state->client_first_message_bare,
1352                                           strlen(state->client_first_message_bare));
1353         scram_HMAC_update(&ctx, ",", 1);
1354         scram_HMAC_update(&ctx,
1355                                           state->server_first_message,
1356                                           strlen(state->server_first_message));
1357         scram_HMAC_update(&ctx, ",", 1);
1358         scram_HMAC_update(&ctx,
1359                                           state->client_final_message_without_proof,
1360                                           strlen(state->client_final_message_without_proof));
1361         scram_HMAC_final(ServerSignature, &ctx);
1362
1363         siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
1364         /* don't forget the zero-terminator */
1365         server_signature_base64 = palloc(siglen + 1);
1366         siglen = pg_b64_encode((const char *) ServerSignature,
1367                                                    SCRAM_KEY_LEN, server_signature_base64,
1368                                                    siglen);
1369         if (siglen < 0)
1370                 elog(ERROR, "could not encode server signature");
1371         server_signature_base64[siglen] = '\0';
1372
1373         /*------
1374          * The syntax for the server-final-message is: (RFC 5802)
1375          *
1376          * verifier                = "v=" base64
1377          *                                       ;; base-64 encoded ServerSignature.
1378          *
1379          * server-final-message = (server-error / verifier)
1380          *                                       ["," extensions]
1381          *
1382          *------
1383          */
1384         return psprintf("v=%s", server_signature_base64);
1385 }
1386
1387
1388 /*
1389  * Deterministically generate salt for mock authentication, using a SHA256
1390  * hash based on the username and a cluster-level secret key.  Returns a
1391  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN.
1392  */
1393 static char *
1394 scram_mock_salt(const char *username)
1395 {
1396         pg_sha256_ctx ctx;
1397         static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
1398         char       *mock_auth_nonce = GetMockAuthenticationNonce();
1399
1400         /*
1401          * Generate salt using a SHA256 hash of the username and the cluster's
1402          * mock authentication nonce.  (This works as long as the salt length is
1403          * not larger the SHA256 digest length. If the salt is smaller, the caller
1404          * will just ignore the extra data.)
1405          */
1406         StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
1407                                          "salt length greater than SHA256 digest length");
1408
1409         pg_sha256_init(&ctx);
1410         pg_sha256_update(&ctx, (uint8 *) username, strlen(username));
1411         pg_sha256_update(&ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN);
1412         pg_sha256_final(&ctx, sha_digest);
1413
1414         return (char *) sha_digest;
1415 }