1
0
mirror of https://github.com/postgres/postgres.git synced 2025-09-02 04:21:28 +03:00

Introduce safer encoding and decoding routines for base64.c

This is a follow-up refactoring after 09ec55b and b674211, which has
proved that the encoding and decoding routines used by SCRAM have a
poor interface when it comes to check after buffer overflows.  This adds
an extra argument in the shape of the length of the result buffer for
each routine, which is used for overflow checks when encoding or
decoding an input string.  The original idea comes from Tom Lane.

As a result of that, the encoding routine can now fail, so all its
callers are adjusted to generate proper error messages in case of
problems.

On failure, the result buffer gets zeroed.

Author: Michael Paquier
Reviewed-by: Daniel Gustafsson
Discussion: https://postgr.es/m/20190623132535.GB1628@paquier.xyz
This commit is contained in:
Michael Paquier
2019-07-04 16:08:09 +09:00
parent d5ab9a891c
commit cfc40d384a
5 changed files with 210 additions and 46 deletions

View File

@@ -42,10 +42,11 @@ static const int8 b64lookup[128] = {
* pg_b64_encode
*
* Encode into base64 the given string. Returns the length of the encoded
* string.
* string, and -1 in the event of an error with the result buffer zeroed
* for safety.
*/
int
pg_b64_encode(const char *src, int len, char *dst)
pg_b64_encode(const char *src, int len, char *dst, int dstlen)
{
char *p;
const char *s,
@@ -65,6 +66,13 @@ pg_b64_encode(const char *src, int len, char *dst)
/* write it out */
if (pos < 0)
{
/*
* Leave if there is an overflow in the area allocated for the
* encoded string.
*/
if ((p - dst + 4) > dstlen)
goto error;
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = _base64[(buf >> 6) & 0x3f];
@@ -76,23 +84,36 @@ pg_b64_encode(const char *src, int len, char *dst)
}
if (pos != 2)
{
/*
* Leave if there is an overflow in the area allocated for the encoded
* string.
*/
if ((p - dst + 4) > dstlen)
goto error;
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
*p++ = '=';
}
Assert((p - dst) <= dstlen);
return p - dst;
error:
memset(dst, 0, dstlen);
return -1;
}
/*
* pg_b64_decode
*
* Decode the given base64 string. Returns the length of the decoded
* string on success, and -1 in the event of an error.
* string on success, and -1 in the event of an error with the result
* buffer zeroed for safety.
*/
int
pg_b64_decode(const char *src, int len, char *dst)
pg_b64_decode(const char *src, int len, char *dst, int dstlen)
{
const char *srcend = src + len,
*s = src;
@@ -109,7 +130,7 @@ pg_b64_decode(const char *src, int len, char *dst)
/* Leave if a whitespace is found */
if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
return -1;
goto error;
if (c == '=')
{
@@ -126,7 +147,7 @@ pg_b64_decode(const char *src, int len, char *dst)
* Unexpected "=" character found while decoding base64
* sequence.
*/
return -1;
goto error;
}
}
b = 0;
@@ -139,7 +160,7 @@ pg_b64_decode(const char *src, int len, char *dst)
if (b < 0)
{
/* invalid symbol found */
return -1;
goto error;
}
}
/* add it to buffer */
@@ -147,11 +168,28 @@ pg_b64_decode(const char *src, int len, char *dst)
pos++;
if (pos == 4)
{
/*
* Leave if there is an overflow in the area allocated for the
* decoded string.
*/
if ((p - dst + 1) > dstlen)
goto error;
*p++ = (buf >> 16) & 255;
if (end == 0 || end > 1)
{
/* overflow check */
if ((p - dst + 1) > dstlen)
goto error;
*p++ = (buf >> 8) & 255;
}
if (end == 0 || end > 2)
{
/* overflow check */
if ((p - dst + 1) > dstlen)
goto error;
*p++ = buf & 255;
}
buf = 0;
pos = 0;
}
@@ -163,10 +201,15 @@ pg_b64_decode(const char *src, int len, char *dst)
* base64 end sequence is invalid. Input data is missing padding, is
* truncated or is otherwise corrupted.
*/
return -1;
goto error;
}
Assert((p - dst) <= dstlen);
return p - dst;
error:
memset(dst, 0, dstlen);
return -1;
}
/*

View File

@@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
char *result;
char *p;
int maxlen;
int encoded_salt_len;
int encoded_stored_len;
int encoded_server_len;
int encoded_result;
if (iterations <= 0)
iterations = SCRAM_DEFAULT_ITERATIONS;
@@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
* SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
*----------
*/
encoded_salt_len = pg_b64_enc_len(saltlen);
encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
maxlen = strlen("SCRAM-SHA-256") + 1
+ 10 + 1 /* iteration count */
+ pg_b64_enc_len(saltlen) + 1 /* Base64-encoded salt */
+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */
+ pg_b64_enc_len(SCRAM_KEY_LEN) + 1; /* Base64-encoded ServerKey */
+ encoded_salt_len + 1 /* Base64-encoded salt */
+ encoded_stored_len + 1 /* Base64-encoded StoredKey */
+ encoded_server_len + 1; /* Base64-encoded ServerKey */
#ifdef FRONTEND
result = malloc(maxlen);
@@ -231,11 +239,50 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
p += pg_b64_encode(salt, saltlen, p);
/* salt */
encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
if (encoded_result < 0)
{
#ifdef FRONTEND
free(result);
return NULL;
#else
elog(ERROR, "could not encode salt");
#endif
}
p += encoded_result;
*(p++) = '$';
p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p);
/* stored key */
encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
encoded_stored_len);
if (encoded_result < 0)
{
#ifdef FRONTEND
free(result);
return NULL;
#else
elog(ERROR, "could not encode stored key");
#endif
}
p += encoded_result;
*(p++) = ':';
p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p);
/* server key */
encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
encoded_server_len);
if (encoded_result < 0)
{
#ifdef FRONTEND
free(result);
return NULL;
#else
elog(ERROR, "could not encode server key");
#endif
}
p += encoded_result;
*(p++) = '\0';
Assert(p - result <= maxlen);