From b411b19b926983e772c0218ebfdd379d79621043 Mon Sep 17 00:00:00 2001 From: lhazlewood <121180+lhazlewood@users.noreply.github.com> Date: Tue, 3 Oct 2023 12:27:12 -0700 Subject: [PATCH] key byte array cleanup as necessary (#846) --- .../jsonwebtoken/impl/security/ConcatKDF.java | 82 ++++++++++++------- .../impl/security/EcdhKeyAlgorithm.java | 6 +- .../impl/security/HmacAesAeadAlgorithm.java | 33 ++++++-- 3 files changed, 86 insertions(+), 35 deletions(-) diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java b/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java index 5bd2f8ef..37439845 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java @@ -15,6 +15,7 @@ */ package io.jsonwebtoken.impl.security; +import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.security.SecurityException; @@ -114,43 +115,68 @@ final class ConcatKDF extends CryptoAlgorithm { long inputBitLength = bitLength(counter) + bitLength(Z) + bitLength(OtherInfo); Assert.state(inputBitLength <= MAX_HASH_INPUT_BIT_LENGTH, "Hash input is too large."); - byte[] derivedKeyBytes = jca().withMessageDigest(new CheckedFunction() { - @Override - public byte[] apply(MessageDigest md) throws Exception { + final ClearableByteArrayOutputStream stream = new ClearableByteArrayOutputStream((int) derivedKeyByteLength); + byte[] derivedKeyBytes = EMPTY; - final ByteArrayOutputStream stream = new ByteArrayOutputStream((int) derivedKeyByteLength); + try { + derivedKeyBytes = jca().withMessageDigest(new CheckedFunction() { + @Override + public byte[] apply(MessageDigest md) throws Exception { - // Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1 - // (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm - // notation convention (so variables like Ki and kLast below match the NIST definitions). - for (long i = 1; i <= reps; i++) { + // Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1 + // (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm + // notation convention (so variables like Ki and kLast below match the NIST definitions). + for (long i = 1; i <= reps; i++) { - // Section 5.8.1.1, Process step #5.1: - md.update(counter); - md.update(Z); - md.update(OtherInfo); - byte[] Ki = md.digest(); + // Section 5.8.1.1, Process step #5.1: + md.update(counter); + md.update(Z); + md.update(OtherInfo); + byte[] Ki = md.digest(); - // Section 5.8.1.1, Process step #5.2: - increment(counter); + // Section 5.8.1.1, Process step #5.2: + increment(counter); - // Section 5.8.1.1, Process step #6: - if (i == reps && kLastPartial) { - long leftmostBitLength = derivedKeyBitLength % hashBitLength; - int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE); - byte[] kLast = new byte[leftmostByteLength]; - System.arraycopy(Ki, 0, kLast, 0, kLast.length); - Ki = kLast; + // Section 5.8.1.1, Process step #6: + if (i == reps && kLastPartial) { + long leftmostBitLength = derivedKeyBitLength % hashBitLength; + int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE); + byte[] kLast = new byte[leftmostByteLength]; + System.arraycopy(Ki, 0, kLast, 0, kLast.length); + Ki = kLast; + } + + stream.write(Ki); } - stream.write(Ki); + // Section 5.8.1.1, Process step #7: + return stream.toByteArray(); } + }); + return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME); + } finally { + // key cleanup + Bytes.clear(derivedKeyBytes); // SecretKeySpec clones this, so we can clear it out safely + Bytes.clear(counter); + stream.reset(); + // we don't clear out 'Z', since that is the responsibility of the caller + } + } - // Section 5.8.1.1, Process step #7: - return stream.toByteArray(); - } - }); + /** + * Calling ByteArrayOutputStream.toByteArray returns a copy of the bytes, so this class allows us to completely + * zero-out the buffer upon reset (whereas BAOS just resets the position marker, leaving the bytes in tact) + */ + private static class ClearableByteArrayOutputStream extends ByteArrayOutputStream { - return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME); + public ClearableByteArrayOutputStream(int size) { + super(size); + } + + @Override + public synchronized void reset() { + super.reset(); + Bytes.clear(buf); // zero out internal buffer + } } } diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java b/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java index 8e771131..97aec5cf 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java @@ -139,7 +139,11 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm