[Concept,01/24] aes: Fix key size handling for AES-192 and AES-256
Commit Message
From: Simon Glass <sjg@chromium.org>
At present the aes_get_rounds() and aes_get_keycols() functions compare
the key_len parameter (in bits) directly against AES*_KEY_LENGTH
constants (in bytes), causing incorrect round and column counts for
non-128-bit keys.
Additionally, aes_expand_key() uses key_len as a byte count in memcpy(),
copying far more data than intended and causing buffer overflows.
Specifically, for AES-256 (256-bit key) it comparies 256 (bits) against
32 (bytes), failing the comparison. This causes AES-256 to use AES-128
parameters (10 rounds instead of 14) and the memcpy() to copy 256 bytes
instead of 32.
Fix by converting key_len from bits to bytes before comparisons and in
memcpy. With this we get:
- AES-128 (128 bits / 16 bytes): 10 rounds, 4 key columns
- AES-192 (192 bits / 24 bytes): 12 rounds, 6 key columns
- AES-256 (256 bits / 32 bytes): 14 rounds, 8 key columns
Co-developed-by: Claude <noreply@anthropic.com>
Signed-off-by: Simon Glass <sjg@chromium.org>
Fixes: 8302d1708ae ("aes: add support of aes192 and aes256")
---
lib/aes.c | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
@@ -513,10 +513,11 @@ static u8 rcon[11] = {
static u32 aes_get_rounds(u32 key_len)
{
u32 rounds = AES128_ROUNDS;
+ u32 key_len_bytes = key_len / 8; /* Convert bits to bytes */
- if (key_len == AES192_KEY_LENGTH)
+ if (key_len_bytes == AES192_KEY_LENGTH)
rounds = AES192_ROUNDS;
- else if (key_len == AES256_KEY_LENGTH)
+ else if (key_len_bytes == AES256_KEY_LENGTH)
rounds = AES256_ROUNDS;
return rounds;
@@ -525,10 +526,11 @@ static u32 aes_get_rounds(u32 key_len)
static u32 aes_get_keycols(u32 key_len)
{
u32 keycols = AES128_KEYCOLS;
+ u32 key_len_bytes = key_len / 8; /* Convert bits to bytes */
- if (key_len == AES192_KEY_LENGTH)
+ if (key_len_bytes == AES192_KEY_LENGTH)
keycols = AES192_KEYCOLS;
- else if (key_len == AES256_KEY_LENGTH)
+ else if (key_len_bytes == AES256_KEY_LENGTH)
keycols = AES256_KEYCOLS;
return keycols;
@@ -538,12 +540,13 @@ static u32 aes_get_keycols(u32 key_len)
void aes_expand_key(u8 *key, u32 key_len, u8 *expkey)
{
u8 tmp0, tmp1, tmp2, tmp3, tmp4;
- u32 idx, aes_rounds, aes_keycols;
+ uint idx, aes_rounds, aes_keycols;
aes_rounds = aes_get_rounds(key_len);
aes_keycols = aes_get_keycols(key_len);
- memcpy(expkey, key, key_len);
+ /* key_len is in bits; convert to bytes */
+ memcpy(expkey, key, key_len / 8);
for (idx = aes_keycols; idx < AES_STATECOLS * (aes_rounds + 1); idx++) {
tmp0 = expkey[4*idx - 4];