]> granicus.if.org Git - pgbouncer/blob - src/scram.c
Add SASL and SCRAM support
[pgbouncer] / src / scram.c
1 /*
2  * PgBouncer - Lightweight connection pooler for PostgreSQL.
3  *
4  * Copyright (c) 2007-2009  Marko Kreen, Skype Technologies OÜ
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18
19 /*
20  * SCRAM support
21  */
22
#include "bouncer.h"
#include "scram.h"
#include "common/base64.h"
#include "common/saslprep.h"
#include "common/scram-common.h"

#include <limits.h>
28
29
30 static bool calculate_client_proof(ScramState *scram_state,
31                                    const char *passwd,
32                                    const char *salt,
33                                    int saltlen,
34                                    int iterations,
35                                    const char *client_final_message_without_proof,
36                                    uint8_t *result);
37
38
39 /*
40  * free SCRAM state info after auth is done
41  */
42 void free_scram_state(ScramState *scram_state)
43 {
44         free(scram_state->client_nonce);
45         free(scram_state->client_first_message_bare);
46         free(scram_state->client_final_message_without_proof);
47         free(scram_state->server_nonce);
48         free(scram_state->server_first_message);
49         free(scram_state->SaltedPassword);
50         free(scram_state->salt);
51         memset(scram_state, 0, sizeof(*scram_state));
52 }
53
/*
 * Return true if every character of the NUL-terminated string 'p' is a
 * "printable" character in the sense of the SCRAM spec (RFC 5802):
 *
 *  printable       = %x21-2B / %x2D-7E
 *                    ;; Printable ASCII except ",".
 *                    ;; Note that any "printable" is also
 *                    ;; a valid "value".
 *
 * The empty string is considered printable.
 */
static bool is_scram_printable(char *p)
{
	size_t idx;

	for (idx = 0; p[idx] != '\0'; idx++) {
		char ch = p[idx];
		bool visible_ascii = (ch >= 0x21 && ch <= 0x7E);

		if (!visible_ascii || ch == ',')
			return false;
	}
	return true;
}
71
/*
 * Render one character for a log message: visible ASCII is shown quoted
 * ('x'); anything else is shown as a hex escape (0xNN).
 *
 * The result lives in a static buffer that is overwritten by the next call.
 */
static char *sanitize_char(char c)
{
	static char out[5];
	bool visible = (c >= 0x21 && c <= 0x7E);

	if (!visible) {
		snprintf(out, sizeof(out), "0x%02x", (unsigned char) c);
		return out;
	}

	snprintf(out, sizeof(out), "'%c'", c);
	return out;
}
82
83 /*
84  * Read value for an attribute part of a SCRAM message.
85  */
86 static char *read_attr_value(PgSocket *sk, char **input, char attr)
87 {
88         char *begin = *input;
89         char *end;
90
91         if (*begin != attr)
92         {
93                 slog_error(sk, "malformed SCRAM message (attribute \"%c\" expected)",
94                            attr);
95                 return NULL;
96         }
97         begin++;
98
99         if (*begin != '=')
100         {
101                 slog_error(sk, "malformed SCRAM message (expected \"=\" after attribute \"%c\")",
102                            attr);
103                 return NULL;
104         }
105         begin++;
106
107         end = begin;
108         while (*end && *end != ',')
109                 end++;
110
111         if (*end)
112         {
113                 *end = '\0';
114                 *input = end + 1;
115         }
116         else
117                 *input = end;
118
119         return begin;
120 }
121
122 /*
123  * Read the next attribute and value in a SCRAM exchange message.
124  *
125  * Returns NULL if there is no attribute.
126  */
127 static char *
128 read_any_attr(PgSocket *sk, char **input, char *attr_p)
129 {
130         char *begin = *input;
131         char *end;
132         char attr = *begin;
133
134         if (!((attr >= 'A' && attr <= 'Z') ||
135               (attr >= 'a' && attr <= 'z')))
136         {
137                 slog_error(sk, "malformed SCRAM message (attribute expected, but found invalid character \"%s\")",
138                            sanitize_char(attr));
139                 return NULL;
140         }
141         if (attr_p)
142                 *attr_p = attr;
143         begin++;
144
145         if (*begin != '=')
146         {
147                 slog_error(sk, "malformed SCRAM message (expected character \"=\" after attribute \"%c\")",
148                            attr);
149                 return NULL;
150         }
151         begin++;
152
153         end = begin;
154         while (*end && *end != ',')
155                 end++;
156
157         if (*end)
158         {
159                 *end = '\0';
160                 *input = end + 1;
161         }
162         else
163                 *input = end;
164
165         return begin;
166 }
167
168 /*
169  * Parse and validate format of given SCRAM verifier.
170  *
171  * Returns true if the SCRAM verifier has been parsed, and false otherwise.
172  */
173 static bool parse_scram_verifier(const char *verifier, int *iterations, char **salt,
174                                  uint8_t *stored_key, uint8_t *server_key)
175 {
176         char       *v;
177         char       *p;
178         char       *scheme_str;
179         char       *salt_str;
180         char       *iterations_str;
181         char       *storedkey_str;
182         char       *serverkey_str;
183         int                     decoded_len;
184         char       *decoded_salt_buf;
185         char       *decoded_stored_buf = NULL;
186         char       *decoded_server_buf = NULL;
187
188         /*
189          * The verifier is of form:
190          *
191          * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
192          */
193         v = strdup(verifier);
194         if (!v)
195                 goto invalid_verifier;
196         if ((scheme_str = strtok(v, "$")) == NULL)
197                 goto invalid_verifier;
198         if ((iterations_str = strtok(NULL, ":")) == NULL)
199                 goto invalid_verifier;
200         if ((salt_str = strtok(NULL, "$")) == NULL)
201                 goto invalid_verifier;
202         if ((storedkey_str = strtok(NULL, ":")) == NULL)
203                 goto invalid_verifier;
204         if ((serverkey_str = strtok(NULL, "")) == NULL)
205                 goto invalid_verifier;
206
207         /* Parse the fields */
208         if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
209                 goto invalid_verifier;
210
211         errno = 0;
212         *iterations = strtol(iterations_str, &p, 10);
213         if (*p || errno != 0)
214                 goto invalid_verifier;
215
216         /*
217          * Verify that the salt is in Base64-encoded format, by decoding it,
218          * although we return the encoded version to the caller.
219          */
220         decoded_salt_buf = malloc(pg_b64_dec_len(strlen(salt_str)));
221         if (!decoded_salt_buf)
222                 goto invalid_verifier;
223         decoded_len = pg_b64_decode(salt_str, strlen(salt_str), decoded_salt_buf);
224         free(decoded_salt_buf);
225         if (decoded_len < 0)
226                 goto invalid_verifier;
227         *salt = strdup(salt_str);
228
229         /*
230          * Decode StoredKey and ServerKey.
231          */
232         decoded_stored_buf = malloc(pg_b64_dec_len(strlen(storedkey_str)));
233         if (!decoded_stored_buf)
234                 goto invalid_verifier;
235         decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_stored_buf);
236         if (decoded_len != SCRAM_KEY_LEN)
237                 goto invalid_verifier;
238         memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
239
240         decoded_server_buf = malloc(pg_b64_dec_len(strlen(serverkey_str)));
241         decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
242                                     decoded_server_buf);
243         if (decoded_len != SCRAM_KEY_LEN)
244                 goto invalid_verifier;
245         memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
246
247         free(decoded_stored_buf);
248         free(decoded_server_buf);
249         free(v);
250         return true;
251
252 invalid_verifier:
253         free(decoded_stored_buf);
254         free(decoded_server_buf);
255         free(v);
256         *salt = NULL;
257         return false;
258 }
259
260 #define MD5_PASSWD_CHARSET "0123456789abcdef"
261
262 /*
263  * What kind of a password verifier is 'shadow_pass'?
264  */
265 PasswordType
266 get_password_type(const char *shadow_pass)
267 {
268         char *encoded_salt = NULL;
269         int iterations;
270         uint8_t stored_key[SCRAM_KEY_LEN];
271         uint8_t server_key[SCRAM_KEY_LEN];
272
273         if (strncmp(shadow_pass, "md5", 3) == 0 &&
274             strlen(shadow_pass) == MD5_PASSWD_LEN &&
275             strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
276                 return PASSWORD_TYPE_MD5;
277         if (parse_scram_verifier(shadow_pass, &iterations, &encoded_salt,
278                                  stored_key, server_key)) {
279                 free(encoded_salt);
280                 return PASSWORD_TYPE_SCRAM_SHA_256;
281         }
282         free(encoded_salt);
283         return PASSWORD_TYPE_PLAINTEXT;
284 }
285
286 /*
287  * Functions for communicating as a client with the server
288  */
289
290 char *build_client_first_message(ScramState *scram_state)
291 {
292         uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
293         int encoded_len;
294         size_t len;
295         char *result = NULL;
296
297         get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);
298
299         scram_state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
300         if (scram_state->client_nonce == NULL)
301                 goto failed;
302         encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->client_nonce);
303         scram_state->client_nonce[encoded_len] = '\0';
304
305         len = 8 + strlen(scram_state->client_nonce) + 1;
306         result = malloc(len);
307         if (result == NULL)
308                 goto failed;
309         snprintf(result, len, "n,,n=,r=%s", scram_state->client_nonce);
310
311         scram_state->client_first_message_bare = strdup(result + 3);
312         if (scram_state->client_first_message_bare == NULL)
313                 goto failed;
314
315         return result;
316
317 failed:
318         free(result);
319         free(scram_state->client_nonce);
320         free(scram_state->client_first_message_bare);
321         return NULL;
322 }
323
324 char *build_client_final_message(ScramState *scram_state,
325                                  const char *passwd,
326                                  const char *server_nonce,
327                                  const char *salt,
328                                  int saltlen,
329                                  int iterations)
330 {
331         char buf[512];
332         size_t len;
333         uint8_t client_proof[SCRAM_KEY_LEN];
334
335         len = snprintf(buf, sizeof(buf), "c=biws,r=%s", server_nonce);
336
337         scram_state->client_final_message_without_proof = strdup(buf);
338         if (scram_state->client_final_message_without_proof == NULL)
339                 goto failed;
340
341         if (!calculate_client_proof(scram_state, passwd,
342                                     salt, saltlen, iterations, buf,
343                                     client_proof))
344                 goto failed;
345
346         len = strlcat(buf, ",p=", sizeof(buf));
347         len += pg_b64_encode((char *) client_proof,
348                              SCRAM_KEY_LEN,
349                              buf + len);
350         buf[len] = '\0';
351
352         return strdup(buf);
353 failed:
354         return NULL;
355 }
356
357 bool read_server_first_message(PgSocket *server, char *input,
358                                char **server_nonce_p, char **salt_p, int *saltlen_p, int *iterations_p)
359 {
360         char *server_nonce;
361         char *encoded_salt;
362         char *salt = NULL;
363         int saltlen;
364         char *iterations_str;
365         char *endptr;
366         int iterations;
367
368         server->scram_state.server_first_message = strdup(input);
369         if (server->scram_state.server_first_message == NULL)
370                 goto failed;
371
372         server_nonce = read_attr_value(server, &input, 'r');
373         if (server_nonce == NULL)
374                 goto failed;
375
376         if (strlen(server_nonce) < strlen(server->scram_state.client_nonce) ||
377             memcmp(server_nonce, server->scram_state.client_nonce, strlen(server->scram_state.client_nonce)) != 0)
378         {
379                 slog_error(server, "invalid SCRAM response (nonce mismatch)");
380                 goto failed;
381         }
382
383         encoded_salt = read_attr_value(server, &input, 's');
384         if (encoded_salt == NULL)
385                 goto failed;
386         salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
387         if (salt == NULL)
388                 goto failed;
389         saltlen = pg_b64_decode(encoded_salt,
390                                 strlen(encoded_salt),
391                                 salt);
392         if (saltlen < 0)
393         {
394                 slog_error(server, "malformed SCRAM message (invalid salt)");
395                 goto failed;
396         }
397
398         iterations_str = read_attr_value(server, &input, 'i');
399         if (iterations_str == NULL)
400                 goto failed;
401
402         iterations = strtol(iterations_str, &endptr, 10);
403         if (*endptr != '\0' || iterations < 1)
404         {
405                 slog_error(server, "malformed SCRAM message (invalid iteration count)");
406                 goto failed;
407         }
408
409         if (*input != '\0')
410         {
411                 slog_error(server, "malformed SCRAM message (garbage at end of server-first-message)");
412                 goto failed;
413         }
414
415         *server_nonce_p = server_nonce;
416         *salt_p = salt;
417         *saltlen_p = saltlen;
418         *iterations_p = iterations;
419         return true;
420 failed:
421         free(salt);
422         return false;
423 }
424
425 bool read_server_final_message(PgSocket *server, char *input, char *ServerSignature)
426 {
427         char *encoded_server_signature;
428         char *decoded_server_signature = NULL;
429         int server_signature_len;
430
431         if (*input == 'e')
432         {
433                 char *errmsg = read_attr_value(server, &input, 'e');
434                 slog_error(server, "error received from server in SCRAM exchange: %s",
435                            errmsg);
436                 goto failed;
437         }
438
439         encoded_server_signature = read_attr_value(server, &input, 'v');
440         if (encoded_server_signature == NULL)
441                 goto failed;
442
443         if (*input != '\0')
444                 slog_error(server, "malformed SCRAM message (garbage at end of server-final-message)");
445
446         server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
447         decoded_server_signature = malloc(server_signature_len);
448         if (!decoded_server_signature)
449                 goto failed;
450
451         server_signature_len = pg_b64_decode(encoded_server_signature,
452                                              strlen(encoded_server_signature),
453                                              decoded_server_signature);
454         if (server_signature_len != SCRAM_KEY_LEN)
455         {
456                 slog_error(server, "malformed SCRAM message (malformed server signature)");
457                 goto failed;
458         }
459         memcpy(ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
460
461         free(decoded_server_signature);
462         return true;
463 failed:
464         free(decoded_server_signature);
465         return false;
466 }
467
468 static bool calculate_client_proof(ScramState *scram_state,
469                                    const char *passwd,
470                                    const char *salt,
471                                    int saltlen,
472                                    int iterations,
473                                    const char *client_final_message_without_proof,
474                                    uint8_t *result)
475 {
476         pg_saslprep_rc rc;
477         char *prep_password = NULL;
478         uint8_t StoredKey[SCRAM_KEY_LEN];
479         uint8_t ClientKey[SCRAM_KEY_LEN];
480         uint8_t ClientSignature[SCRAM_KEY_LEN];
481         scram_HMAC_ctx ctx;
482
483         rc = pg_saslprep(passwd, &prep_password);
484         if (rc == SASLPREP_OOM)
485                 false;
486         if (rc != SASLPREP_SUCCESS)
487         {
488                 prep_password = strdup(passwd);
489                 if (!prep_password)
490                         return false;
491         }
492
493         scram_state->SaltedPassword = malloc(SCRAM_KEY_LEN);
494         if (scram_state->SaltedPassword == NULL)
495                 goto failed;
496         scram_SaltedPassword(prep_password,
497                              salt,
498                              saltlen,
499                              iterations,
500                              scram_state->SaltedPassword);
501
502         scram_ClientKey(scram_state->SaltedPassword, ClientKey);
503         scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey);
504
505         scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN);
506         scram_HMAC_update(&ctx,
507                           scram_state->client_first_message_bare,
508                           strlen(scram_state->client_first_message_bare));
509         scram_HMAC_update(&ctx, ",", 1);
510         scram_HMAC_update(&ctx,
511                           scram_state->server_first_message,
512                           strlen(scram_state->server_first_message));
513         scram_HMAC_update(&ctx, ",", 1);
514         scram_HMAC_update(&ctx,
515                           client_final_message_without_proof,
516                           strlen(client_final_message_without_proof));
517         scram_HMAC_final(ClientSignature, &ctx);
518
519         for (int i = 0; i < SCRAM_KEY_LEN; i++)
520                 result[i] = ClientKey[i] ^ ClientSignature[i];
521
522         free(prep_password);
523         return true;
524 failed:
525         free(prep_password);
526         return false;
527 }
528
529 bool verify_server_signature(ScramState *scram_state, const char *ServerSignature)
530 {
531         uint8_t expected_ServerSignature[SCRAM_KEY_LEN];
532         uint8_t ServerKey[SCRAM_KEY_LEN];
533         scram_HMAC_ctx ctx;
534
535         scram_ServerKey(scram_state->SaltedPassword, ServerKey);
536
537         scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN);
538         scram_HMAC_update(&ctx,
539                           scram_state->client_first_message_bare,
540                           strlen(scram_state->client_first_message_bare));
541         scram_HMAC_update(&ctx, ",", 1);
542         scram_HMAC_update(&ctx,
543                           scram_state->server_first_message,
544                           strlen(scram_state->server_first_message));
545         scram_HMAC_update(&ctx, ",", 1);
546         scram_HMAC_update(&ctx,
547                           scram_state->client_final_message_without_proof,
548                           strlen(scram_state->client_final_message_without_proof));
549         scram_HMAC_final(expected_ServerSignature, &ctx);
550
551         if (memcmp(expected_ServerSignature, ServerSignature, SCRAM_KEY_LEN) != 0)
552                 return false;
553
554         return true;
555 }
556
557
558 /*
559  * Functions for communicating as a server to the client
560  */
561
562 bool read_client_first_message(PgSocket *client, char *input,
563                                char **client_first_message_bare_p,
564                                char **client_nonce_p)
565 {
566         char *client_first_message_bare = NULL;
567         char *client_nonce = NULL;
568
569         switch (*input) {
570         case 'n':
571                 /* Client does not support channel binding */
572                 input++;
573                 break;
574         case 'y':
575                 /* Client supports channel binding, but we're not doing it today */
576                 input++;
577                 break;
578         case 'p':
579                 /* Client requires channel binding.  We don't support it. */
580                 slog_error(client, "client requires SCRAM channel binding, but it is not supported");
581                 goto failed;
582         default:
583                 slog_error(client, "malformed SCRAM message (unexpected channel-binding flag \"%s\")",
584                            sanitize_char(*input));
585                 goto failed;
586         }
587
588         if (*input != ',') {
589                 slog_error(client, "malformed SCRAM message (comma expected, but found character \"%s\")",
590                            sanitize_char(*input));
591                 goto failed;
592         }
593         input++;
594
595         if (*input == 'a') {
596                 slog_error(client, "client uses authorization identity, but it is not supported");
597                 goto failed;
598         }
599         if (*input != ',') {
600                 slog_error(client, "malformed SCRAM message (unexpected attribute \"%s\" in client-first-message)",
601                            sanitize_char(*input));
602                 goto failed;
603         }
604         input++;
605
606         client_first_message_bare = strdup(input);
607         if (client_first_message_bare == NULL)
608                 goto failed;
609
610         if (*input == 'm') {
611                 slog_error(client, "client requires an unsupported SCRAM extension");
612                 goto failed;
613         }
614
615         /* read and ignore user name */
616         read_attr_value(client, &input, 'n');
617
618         client_nonce = read_attr_value(client, &input, 'r');
619         if (client_nonce == NULL)
620                 goto failed;
621         if (!is_scram_printable(client_nonce)) {
622                 slog_error(client, "non-printable characters in SCRAM nonce");
623                 goto failed;
624         }
625         client_nonce = strdup(client_nonce);
626         if (client_nonce == NULL)
627                 goto failed;
628
629         /*
630          * There can be any number of optional extensions after this.  We don't
631          * support any extensions, so ignore them.
632          */
633         while (*input != '\0')
634                 read_any_attr(client, &input, NULL);
635
636         *client_first_message_bare_p = client_first_message_bare;
637         *client_nonce_p = client_nonce;
638         return true;
639 failed:
640         free(client_first_message_bare);
641         return false;
642 }
643
644 bool read_client_final_message(PgSocket *client, const uint8_t *raw_input, char *input,
645                                const char **client_final_nonce_p,
646                                char **proof_p)
647 {
648         const char *input_start = input;
649         char attr;
650         char *channel_binding;
651         char *client_final_nonce;
652         char *proof_start;
653         char *encoded_proof;
654         char *proof = NULL;
655         int prooflen;
656
657         /*
658          * Read channel-binding.  We don't support channel binding, so
659          * it's expected to always be "biws", which is "n,,",
660          * base64-encoded.
661          */
662         channel_binding = read_attr_value(client, &input, 'c');
663         if (channel_binding == NULL)
664                 goto failed;
665         if (strcmp(channel_binding, "biws") != 0) {
666                 slog_error(client, "unexpected SCRAM channel-binding attribute in client-final-message");
667                 goto failed;
668         }
669
670         client_final_nonce = read_attr_value(client, &input, 'r');
671
672         /* ignore optional extensions */
673         do
674         {
675                 proof_start = input - 1;
676                 encoded_proof = read_any_attr(client, &input, &attr);
677         } while (attr != 'p');
678
679         if (!encoded_proof) {
680                 slog_error(client, "could not read proof");
681                 goto failed;
682         }
683
684         proof = malloc(pg_b64_dec_len(strlen(encoded_proof)));
685         if (proof == NULL) {
686                 slog_error(client, "could not decode proof");
687                 goto failed;
688         }
689         prooflen = pg_b64_decode(encoded_proof,
690                                  strlen(encoded_proof),
691                                  proof);
692         (void) prooflen;
693
694         if (*input != '\0') {
695                 slog_error(client, "malformed SCRAM message (garbage at the end of client-final-message)");
696                 goto failed;
697         }
698
699         client->scram_state.client_final_message_without_proof = malloc(proof_start - input_start + 1);
700         if (!client->scram_state.client_final_message_without_proof)
701                 goto failed;
702         memcpy(client->scram_state.client_final_message_without_proof, raw_input, proof_start - input_start);
703         client->scram_state.client_final_message_without_proof[proof_start - input_start] = '\0';
704
705         *client_final_nonce_p = client_final_nonce;
706         *proof_p = proof;
707         return true;
708 failed:
709         free(proof);
710         return false;
711 }
712
713 /*
714  * For doing SCRAM with a password stored in plain text, build a SCRAM
715  * secret on the fly.
716  */
717 static bool build_adhoc_scram_secret(const char *plain_password, ScramState *scram_state)
718 {
719         const char *password;
720         char *prep_password;
721         pg_saslprep_rc rc;
722         char saltbuf[SCRAM_DEFAULT_SALT_LEN];
723         int encoded_len;
724         uint8_t salted_password[SCRAM_KEY_LEN];
725
726         rc = pg_saslprep(plain_password, &prep_password);
727         if (rc == SASLPREP_OOM)
728                 goto failed;
729         else if (rc == SASLPREP_SUCCESS)
730                 password = prep_password;
731         else
732                 password = plain_password;
733
734         get_random_bytes((uint8_t *) saltbuf, sizeof(saltbuf));
735
736         scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
737
738         scram_state->salt = malloc(pg_b64_enc_len(sizeof(saltbuf)) + 1);
739         if (!scram_state->salt)
740                 goto failed;
741         encoded_len = pg_b64_encode(saltbuf, sizeof(saltbuf), scram_state->salt);
742         scram_state->salt[encoded_len] = '\0';
743
744         /* Calculate StoredKey and ServerKey */
745         scram_SaltedPassword(password, saltbuf, sizeof(saltbuf),
746                              scram_state->iterations,
747                              salted_password);
748         scram_ClientKey(salted_password, scram_state->StoredKey);
749         scram_H(scram_state->StoredKey, SCRAM_KEY_LEN, scram_state->StoredKey);
750         scram_ServerKey(salted_password, scram_state->ServerKey);
751
752         if (prep_password)
753                 free(prep_password);
754         return true;
755 failed:
756         if (prep_password)
757                 free(prep_password);
758         return false;
759 }
760
761 char *build_server_first_message(ScramState *scram_state, const char *stored_secret)
762 {
763         uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
764         int encoded_len;
765         size_t len;
766         char *result;
767
768         switch (get_password_type(stored_secret)) {
769         case PASSWORD_TYPE_SCRAM_SHA_256:
770                 if (!parse_scram_verifier(stored_secret,
771                                           &scram_state->iterations,
772                                           &scram_state->salt,
773                                           scram_state->StoredKey,
774                                           scram_state->ServerKey))
775                         goto failed;
776                 break;
777         case PASSWORD_TYPE_PLAINTEXT:
778                 if (!build_adhoc_scram_secret(stored_secret, scram_state))
779                         goto failed;
780                 break;
781         default:
782                 /* shouldn't get here */
783                 goto failed;
784         }
785
786         get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);
787         scram_state->server_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
788         if (scram_state->server_nonce == NULL)
789                 goto failed;
790         encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->server_nonce);
791         scram_state->server_nonce[encoded_len] = '\0';
792
793         len = (2
794                + strlen(scram_state->client_nonce)
795                + strlen(scram_state->server_nonce)
796                + 3
797                + strlen(scram_state->salt)
798                + 3 + 10 + 1);
799         result = malloc(len);
800         if (!result)
801                 goto failed;
802         snprintf(result, len,
803                  "r=%s%s,s=%s,i=%u",
804                  scram_state->client_nonce,
805                  scram_state->server_nonce,
806                  scram_state->salt,
807                  scram_state->iterations);
808
809         scram_state->server_first_message = result;
810
811         return result;
812 failed:
813         free(scram_state->server_nonce);
814         free(scram_state->server_first_message);
815         return NULL;
816 }
817
818 static char *
819 compute_server_signature(ScramState *state)
820 {
821         uint8_t         ServerSignature[SCRAM_KEY_LEN];
822         char       *server_signature_base64;
823         int                     siglen;
824         scram_HMAC_ctx ctx;
825
826         /* calculate ServerSignature */
827         scram_HMAC_init(&ctx, state->ServerKey, SCRAM_KEY_LEN);
828         scram_HMAC_update(&ctx,
829                           state->client_first_message_bare,
830                           strlen(state->client_first_message_bare));
831         scram_HMAC_update(&ctx, ",", 1);
832         scram_HMAC_update(&ctx,
833                           state->server_first_message,
834                           strlen(state->server_first_message));
835         scram_HMAC_update(&ctx, ",", 1);
836         scram_HMAC_update(&ctx,
837                           state->client_final_message_without_proof,
838                           strlen(state->client_final_message_without_proof));
839         scram_HMAC_final(ServerSignature, &ctx);
840
841         server_signature_base64 = malloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
842         if (!server_signature_base64)
843                 return NULL;
844         siglen = pg_b64_encode((const char *) ServerSignature,
845                                                    SCRAM_KEY_LEN, server_signature_base64);
846         server_signature_base64[siglen] = '\0';
847
848         return server_signature_base64;
849 }
850
851 char *build_server_final_message(ScramState *scram_state)
852 {
853         char *server_signature = NULL;
854         size_t len;
855         char *result;
856
857         server_signature = compute_server_signature(scram_state);
858         if (!server_signature)
859                 goto failed;
860
861         len = 2 + strlen(server_signature) + 1;
862         result = malloc(len);
863         if (!result)
864                 goto failed;
865         snprintf(result, len, "v=%s", server_signature);
866
867         free(server_signature);
868         return result;
869 failed:
870         free(server_signature);
871         return NULL;
872 }
873
874 bool verify_final_nonce(const ScramState *scram_state, const char *client_final_nonce)
875 {
876         size_t client_nonce_len = strlen(scram_state->client_nonce);
877         size_t server_nonce_len = strlen(scram_state->server_nonce);
878         size_t final_nonce_len = strlen(client_final_nonce);
879
880         if (final_nonce_len != client_nonce_len + server_nonce_len)
881                 return false;
882         if (memcmp(client_final_nonce, scram_state->client_nonce, client_nonce_len) != 0)
883                 return false;
884         if (memcmp(client_final_nonce + client_nonce_len, scram_state->server_nonce, server_nonce_len) != 0)
885                 return false;
886
887         return true;
888 }
889
890 bool verify_client_proof(ScramState *state, const char *ClientProof)
891 {
892     uint8_t ClientSignature[SCRAM_KEY_LEN];
893     uint8_t ClientKey[SCRAM_KEY_LEN];
894     uint8_t client_StoredKey[SCRAM_KEY_LEN];
895     scram_HMAC_ctx ctx;
896     int i;
897
898     /* calculate ClientSignature */
899     scram_HMAC_init(&ctx, state->StoredKey, SCRAM_KEY_LEN);
900     scram_HMAC_update(&ctx,
901                       state->client_first_message_bare,
902                       strlen(state->client_first_message_bare));
903     scram_HMAC_update(&ctx, ",", 1);
904     scram_HMAC_update(&ctx,
905                       state->server_first_message,
906                       strlen(state->server_first_message));
907     scram_HMAC_update(&ctx, ",", 1);
908     scram_HMAC_update(&ctx,
909                       state->client_final_message_without_proof,
910                       strlen(state->client_final_message_without_proof));
911     scram_HMAC_final(ClientSignature, &ctx);
912
913     /* Extract the ClientKey that the client calculated from the proof */
914     for (i = 0; i < SCRAM_KEY_LEN; i++)
915             ClientKey[i] = ClientProof[i] ^ ClientSignature[i];
916
917     /* Hash it one more time, and compare with StoredKey */
918     scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey);
919
920     if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
921             return false;
922
923     return true;
924 }
925
926 /*
927  * Verify a plaintext password against a SCRAM verifier.  This is used when
928  * performing plaintext password authentication for a user that has a SCRAM
929  * verifier stored in pg_authid.
930  */
931 bool
932 scram_verify_plain_password(PgSocket *client,
933                             const char *username, const char *password,
934                             const char *verifier)
935 {
936         char *encoded_salt = NULL;
937         char *salt = NULL;
938         int saltlen;
939         int iterations;
940         uint8_t salted_password[SCRAM_KEY_LEN];
941         uint8_t stored_key[SCRAM_KEY_LEN];
942         uint8_t server_key[SCRAM_KEY_LEN];
943         uint8_t computed_key[SCRAM_KEY_LEN];
944         char *prep_password = NULL;
945         pg_saslprep_rc rc;
946         bool result = false;
947
948         if (!parse_scram_verifier(verifier, &iterations, &encoded_salt,
949                                   stored_key, server_key))
950         {
951                 /* The password looked like a SCRAM verifier, but could not be parsed. */
952                 slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
953                 goto failed;
954         }
955
956         salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
957         if (!salt)
958                 goto failed;
959         saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
960         if (saltlen == -1)
961         {
962                 slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
963                 goto failed;
964         }
965
966         /* Normalize the password */
967         rc = pg_saslprep(password, &prep_password);
968         if (rc == SASLPREP_SUCCESS)
969                 password = prep_password;
970
971         /* Compute Server Key based on the user-supplied plaintext password */
972         scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
973         scram_ServerKey(salted_password, computed_key);
974
975         /*
976          * Compare the verifier's Server Key with the one computed from the
977          * user-supplied password.
978          */
979         result = memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
980
981 failed:
982         free(encoded_salt);
983         free(salt);
984         free(prep_password);
985         return result;
986 }