1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-07-30 22:43:08 +03:00

mbedtls_base64_decode: insist on correct padding

Correct base64 input (excluding ignored characters such as spaces) consists
of exactly 4*k, 4*k-1 or 4*k-2 digits, followed by 0, 1 or 2 equal signs
respectively.

Previously, any number of trailing equal signs up to 2 was accepted, but if
there fewer than 4*k digits-or-equals, the last partial block was counted in
`*olen` in buffer-too-small mode, but was not output despite returning 0.

Now `mbedtls_base64_decode()` insists on correct padding. This is
backward-compatible since the only plausible useful inputs that used to be
accepted were inputs with 4*k-1 or 4*k-2 digits and no trailing equal signs,
and those led to invalid (truncated) output. Furthermore the function now
always reports the exact output size in buffer-too-small mode.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2025-06-04 11:22:25 +02:00
parent 84999d1a7b
commit 2b3d6a8f28
4 changed files with 67 additions and 78 deletions

View File

@ -14,6 +14,7 @@
#include "mbedtls/base64.h"
#include "base64_internal.h"
#include "constant_time_internal.h"
#include "mbedtls/error.h"
#include <stdint.h>
@ -183,55 +184,57 @@ int mbedtls_base64_decode(unsigned char *dst, size_t dlen, size_t *olen,
n++;
}
/* In valid base64, the number of digits is always of the form
* 4n, 4n+2 or 4n+3. */
/* In valid base64, the number of digits (n-equals) is always of the form
* 4*k, 4*k+2 or *4k+3. Also, the number n of digits plus the number of
* equal signs at the end is always a multiple of 4. */
if ((n - equals) % 4 == 1) {
return MBEDTLS_ERR_BASE64_INVALID_CHARACTER;
}
if (n == 0) {
*olen = 0;
return 0;
if (n % 4 != 0) {
return MBEDTLS_ERR_BASE64_INVALID_CHARACTER;
}
/* The following expression is to calculate the following formula without
* risk of integer overflow in n:
* n = ( ( n * 6 ) + 7 ) >> 3;
*/
n = (6 * (n >> 3)) + ((6 * (n & 0x7) + 7) >> 3);
n -= equals;
/* We've determined that the input is valid, and that it contains
* n digits-plus-trailing-equal-signs, which means (n - equals) digits.
* Now set *olen to the exact length of the output. */
/* Each block of 4 digits in the input map to 3 bytes of output.
* The last block can have one or two equal signs, in which case
* there are that many fewer output bytes. */
*olen = (n / 4) * 3 - equals;
if (dst == NULL || dlen < n) {
*olen = n;
if ((*olen != 0 && dst == NULL) || dlen < *olen) {
return MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL;
}
equals = 0;
for (x = 0, p = dst; i > 0; i--, src++) {
if (*src == '\r' || *src == '\n' || *src == ' ') {
continue;
}
x = x << 6;
if (*src == '=') {
++equals;
} else {
x |= mbedtls_ct_base64_dec_value(*src);
/* We already know from the first loop that equal signs are
* only at the end. */
break;
}
x = x << 6;
x |= mbedtls_ct_base64_dec_value(*src);
if (++accumulated_digits == 4) {
accumulated_digits = 0;
*p++ = MBEDTLS_BYTE_2(x);
if (equals <= 1) {
*p++ = MBEDTLS_BYTE_1(x);
}
if (equals <= 0) {
*p++ = MBEDTLS_BYTE_0(x);
}
*p++ = MBEDTLS_BYTE_1(x);
*p++ = MBEDTLS_BYTE_0(x);
}
}
if (accumulated_digits == 3) {
*p++ = MBEDTLS_BYTE_2(x << 6);
*p++ = MBEDTLS_BYTE_1(x << 6);
} else if (accumulated_digits == 2) {
*p++ = MBEDTLS_BYTE_2(x << 12);
}
*olen = (size_t) (p - dst);
if (*olen != (size_t) (p - dst)) {
return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
}
return 0;
}