return false;
}
- salt = palloc(pg_b64_dec_len(strlen(encoded_salt)));
- saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
- if (saltlen == -1)
+ saltlen = pg_b64_dec_len(strlen(encoded_salt));
+ salt = palloc(saltlen);
+ saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
+ saltlen);
+ if (saltlen < 0)
{
ereport(LOG,
(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
* Verify that the salt is in Base64-encoded format, by decoding it,
* although we return the encoded version to the caller.
*/
- decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
+ decoded_len = pg_b64_dec_len(strlen(salt_str));
+ decoded_salt_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
- decoded_salt_buf);
+ decoded_salt_buf, decoded_len);
if (decoded_len < 0)
goto invalid_verifier;
*salt = pstrdup(salt_str);
/*
* Decode StoredKey and ServerKey.
*/
- decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str)));
+ decoded_len = pg_b64_dec_len(strlen(storedkey_str));
+ decoded_stored_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
- decoded_stored_buf);
+ decoded_stored_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN)
goto invalid_verifier;
memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
- decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str)));
+ decoded_len = pg_b64_dec_len(strlen(serverkey_str));
+ decoded_server_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
- decoded_server_buf);
+ decoded_server_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN)
goto invalid_verifier;
memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
/* Generate deterministic salt */
raw_salt = scram_mock_salt(username);
- encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
- encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
+ encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
+ /* don't forget the zero-terminator */
+ encoded_salt = (char *) palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
+ encoded_len);
+
+ /*
+ * Note that we cannot reveal any information to an attacker here so the
+ * error message needs to remain generic. This should never fail anyway
+ * as the salt generated for mock authentication uses the cluster's nonce
+ * value.
+ */
+ if (encoded_len < 0)
+ elog(ERROR, "could not encode salt");
encoded_salt[encoded_len] = '\0';
*salt = encoded_salt;
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("could not generate random nonce")));
- state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
- encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce);
+ encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+ /* don't forget the zero-terminator */
+ state->server_nonce = palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+ state->server_nonce, encoded_len);
+ if (encoded_len < 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("could not encode random nonce")));
state->server_nonce[encoded_len] = '\0';
state->server_first_message =
*proof;
char *p;
char *client_proof;
+ int client_proof_len;
begin = p = pstrdup(input);
snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
- b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
+ b64_message_len = pg_b64_enc_len(cbind_input_len);
+ /* don't forget the zero-terminator */
+ b64_message = palloc(b64_message_len + 1);
b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
- b64_message);
+ b64_message, b64_message_len);
+ if (b64_message_len < 0)
+ elog(ERROR, "could not encode channel binding data");
b64_message[b64_message_len] = '\0';
/*
value = read_any_attr(&p, &attr);
} while (attr != 'p');
- client_proof = palloc(pg_b64_dec_len(strlen(value)));
- if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN)
+ client_proof_len = pg_b64_dec_len(strlen(value));
+ client_proof = palloc(client_proof_len);
+ if (pg_b64_decode(value, strlen(value), client_proof,
+ client_proof_len) != SCRAM_KEY_LEN)
ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("malformed SCRAM message"),
strlen(state->client_final_message_without_proof));
scram_HMAC_final(ServerSignature, &ctx);
- server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
+ siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+ /* don't forget the zero-terminator */
+ server_signature_base64 = palloc(siglen + 1);
siglen = pg_b64_encode((const char *) ServerSignature,
- SCRAM_KEY_LEN, server_signature_base64);
+ SCRAM_KEY_LEN, server_signature_base64,
+ siglen);
+ if (siglen < 0)
+ elog(ERROR, "could not encode server signature");
server_signature_base64[siglen] = '\0';
/*------
* pg_b64_encode
*
* Encode into base64 the given string. Returns the length of the encoded
- * string.
+ * string, and -1 in the event of an error with the result buffer zeroed
+ * for safety.
*/
int
-pg_b64_encode(const char *src, int len, char *dst)
+pg_b64_encode(const char *src, int len, char *dst, int dstlen)
{
char *p;
const char *s,
/* write it out */
if (pos < 0)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the
+ * encoded string.
+ */
+ if ((p - dst + 4) > dstlen)
+ goto error;
+
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = _base64[(buf >> 6) & 0x3f];
}
if (pos != 2)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the encoded
+ * string.
+ */
+ if ((p - dst + 4) > dstlen)
+ goto error;
+
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
*p++ = '=';
}
+ Assert((p - dst) <= dstlen);
return p - dst;
+
+error:
+ memset(dst, 0, dstlen);
+ return -1;
}
/*
* pg_b64_decode
*
* Decode the given base64 string. Returns the length of the decoded
- * string on success, and -1 in the event of an error.
+ * string on success, and -1 in the event of an error with the result
+ * buffer zeroed for safety.
*/
int
-pg_b64_decode(const char *src, int len, char *dst)
+pg_b64_decode(const char *src, int len, char *dst, int dstlen)
{
const char *srcend = src + len,
*s = src;
/* Leave if a whitespace is found */
if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
- return -1;
+ goto error;
if (c == '=')
{
* Unexpected "=" character found while decoding base64
* sequence.
*/
- return -1;
+ goto error;
}
}
b = 0;
if (b < 0)
{
/* invalid symbol found */
- return -1;
+ goto error;
}
}
/* add it to buffer */
pos++;
if (pos == 4)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the
+ * decoded string.
+ */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = (buf >> 16) & 255;
+
if (end == 0 || end > 1)
+ {
+ /* overflow check */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = (buf >> 8) & 255;
+ }
if (end == 0 || end > 2)
+ {
+ /* overflow check */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = buf & 255;
+ }
buf = 0;
pos = 0;
}
* base64 end sequence is invalid. Input data is missing padding, is
* truncated or is otherwise corrupted.
*/
- return -1;
+ goto error;
}
+ Assert((p - dst) <= dstlen);
return p - dst;
+
+error:
+ memset(dst, 0, dstlen);
+ return -1;
}
/*
char *result;
char *p;
int maxlen;
+ int encoded_salt_len;
+ int encoded_stored_len;
+ int encoded_server_len;
+ int encoded_result;
if (iterations <= 0)
iterations = SCRAM_DEFAULT_ITERATIONS;
* SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
*----------
*/
+ encoded_salt_len = pg_b64_enc_len(saltlen);
+ encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+ encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+
maxlen = strlen("SCRAM-SHA-256") + 1
+ 10 + 1 /* iteration count */
- + pg_b64_enc_len(saltlen) + 1 /* Base64-encoded salt */
- + pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */
- + pg_b64_enc_len(SCRAM_KEY_LEN) + 1; /* Base64-encoded ServerKey */
+ + encoded_salt_len + 1 /* Base64-encoded salt */
+ + encoded_stored_len + 1 /* Base64-encoded StoredKey */
+ + encoded_server_len + 1; /* Base64-encoded ServerKey */
#ifdef FRONTEND
result = malloc(maxlen);
p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
- p += pg_b64_encode(salt, saltlen, p);
+ /* salt */
+ encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode salt");
+#endif
+ }
+ p += encoded_result;
*(p++) = '$';
- p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p);
+
+ /* stored key */
+ encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+ encoded_stored_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode stored key");
+#endif
+ }
+
+ p += encoded_result;
*(p++) = ':';
- p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p);
+
+ /* server key */
+ encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+ encoded_server_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode server key");
+#endif
+ }
+
+ p += encoded_result;
*(p++) = '\0';
Assert(p - result <= maxlen);
#define BASE64_H
/* base 64 */
-extern int pg_b64_encode(const char *src, int len, char *dst);
-extern int pg_b64_decode(const char *src, int len, char *dst);
+extern int pg_b64_encode(const char *src, int len, char *dst, int dstlen);
+extern int pg_b64_decode(const char *src, int len, char *dst, int dstlen);
extern int pg_b64_enc_len(int srclen);
extern int pg_b64_dec_len(int srclen);
return NULL;
}
- state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
+ encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+ /* don't forget the zero-terminator */
+ state->client_nonce = malloc(encoded_len + 1);
if (state->client_nonce == NULL)
{
printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return NULL;
}
- encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->client_nonce);
+ encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+ state->client_nonce, encoded_len);
+ if (encoded_len < 0)
+ {
+ printfPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("could not encode nonce\n"));
+ return NULL;
+ }
state->client_nonce[encoded_len] = '\0';
/*
PGconn *conn = state->conn;
uint8 client_proof[SCRAM_KEY_LEN];
char *result;
+ int encoded_len;
initPQExpBuffer(&buf);
size_t cbind_header_len;
char *cbind_input;
size_t cbind_input_len;
+ int encoded_cbind_len;
/* Fetch hash data of server's SSL certificate */
cbind_data =
memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
- if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
+ encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
+ if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
{
free(cbind_data);
free(cbind_input);
goto oom_error;
}
- buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
+ encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
+ buf.data + buf.len,
+ encoded_cbind_len);
+ if (encoded_cbind_len < 0)
+ {
+ free(cbind_data);
+ free(cbind_input);
+ termPQExpBuffer(&buf);
+ printfPQExpBuffer(&conn->errorMessage,
+ "could not encode cbind data for channel binding\n");
+ return NULL;
+ }
+ buf.len += encoded_cbind_len;
buf.data[buf.len] = '\0';
free(cbind_data);
client_proof);
appendPQExpBufferStr(&buf, ",p=");
- if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(SCRAM_KEY_LEN)))
+ encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+ if (!enlargePQExpBuffer(&buf, encoded_len))
goto oom_error;
- buf.len += pg_b64_encode((char *) client_proof,
- SCRAM_KEY_LEN,
- buf.data + buf.len);
+ encoded_len = pg_b64_encode((char *) client_proof,
+ SCRAM_KEY_LEN,
+ buf.data + buf.len,
+ encoded_len);
+ if (encoded_len < 0)
+ {
+ termPQExpBuffer(&buf);
+ printfPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("could not encode client proof\n"));
+ return NULL;
+ }
+ buf.len += encoded_len;
buf.data[buf.len] = '\0';
result = strdup(buf.data);
char *endptr;
char *encoded_salt;
char *nonce;
+ int decoded_salt_len;
state->server_first_message = strdup(input);
if (state->server_first_message == NULL)
/* read_attr_value() has generated an error string */
return false;
}
- state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
+ decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
+ state->salt = malloc(decoded_salt_len);
if (state->salt == NULL)
{
printfPQExpBuffer(&conn->errorMessage,
}
state->saltlen = pg_b64_decode(encoded_salt,
strlen(encoded_salt),
- state->salt);
+ state->salt,
+ decoded_salt_len);
if (state->saltlen < 0)
{
printfPQExpBuffer(&conn->errorMessage,
server_signature_len = pg_b64_decode(encoded_server_signature,
strlen(encoded_server_signature),
- decoded_server_signature);
+ decoded_server_signature,
+ server_signature_len);
if (server_signature_len != SCRAM_KEY_LEN)
{
free(decoded_server_signature);