JWK .equals and .hashCode (#823)

* Adjusted JWK .equals implementations to only account for kty value and material fields (two JWKs are equal if their type and key material are equal, regardless of other public parameters and/or custom name/value pairs).

* Adjusted JWK .hashCode implementation to pre-cache its value based on JwkThumpbrint fields since JWKs are immutable
This commit is contained in:
lhazlewood 2023-09-12 20:38:01 -07:00 committed by GitHub
parent f60d560297
commit b55f26175c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 471 additions and 37 deletions

View File

@ -176,6 +176,16 @@ public final class Bytes {
return output; return output;
} }
/**
* Clears the array by filling it with all zeros. Does nothing with a null or empty argument.
*
* @param bytes the (possibly null or empty) byte array to clear
*/
public static void clear(byte[] bytes) {
if (isEmpty(bytes)) return;
java.util.Arrays.fill(bytes, (byte) 0);
}
public static boolean isEmpty(byte[] bytes) { public static boolean isEmpty(byte[] bytes) {
return length(bytes) == 0; return length(bytes) == 0;
} }

View File

@ -17,10 +17,12 @@ package io.jsonwebtoken.impl.lang;
import io.jsonwebtoken.lang.Arrays; import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Objects;
import io.jsonwebtoken.lang.Registry; import io.jsonwebtoken.lang.Registry;
import java.math.BigInteger; import java.math.BigInteger;
import java.net.URI; import java.net.URI;
import java.security.MessageDigest;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collection; import java.util.Collection;
import java.util.Date; import java.util.Date;
@ -97,4 +99,46 @@ public final class Fields {
newFields.put(id, field); // add new one newFields.put(id, field); // add new one
return registry(newFields.values()); return registry(newFields.values());
} }
private static byte[] bytes(BigInteger i) {
return i != null ? i.toByteArray() : null;
}
public static boolean bytesEquals(BigInteger a, BigInteger b) {
//noinspection NumberEquality
if (a == b) return true;
if (a == null || b == null) return false;
byte[] aBytes = bytes(a);
byte[] bBytes = bytes(b);
try {
return MessageDigest.isEqual(aBytes, bBytes);
} finally {
Bytes.clear(aBytes);
Bytes.clear(bBytes);
}
}
private static <T> boolean equals(T a, T b, Field<T> field) {
if (a == b) return true;
if (a == null || b == null) return false;
if (field.isSecret()) {
// byte[] and BigInteger are the only types of secret Fields in the JJWT codebase
// (i.e. Field.isSecret() == true). If a Field is ever marked as secret, and it's not one of these two
// data types, we need to know about it. So we use the 'assertSecret' helper above to ensure we do:
if (a instanceof byte[]) {
return b instanceof byte[] && MessageDigest.isEqual((byte[]) a, (byte[]) b);
} else if (a instanceof BigInteger) {
return b instanceof BigInteger && bytesEquals((BigInteger) a, (BigInteger) b);
}
}
// default to a standard null-safe comparison:
return Objects.nullSafeEquals(a, b);
}
public static <T> boolean equals(FieldReadable a, Object o, Field<T> field) {
if (a == o) return true;
if (a == null || !(o instanceof FieldReadable)) return false;
FieldReadable b = (FieldReadable) o;
return equals(a.get(field), b.get(field), field);
}
} }

View File

@ -33,6 +33,9 @@ import io.jsonwebtoken.security.KeyOperation;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.Key; import java.security.Key;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
@ -48,10 +51,11 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
.set().setId("key_ops").setName("Key Operations").build(); .set().setId("key_ops").setName("Key Operations").build();
static final Field<String> KTY = Fields.string("kty", "Key Type"); static final Field<String> KTY = Fields.string("kty", "Key Type");
static final Set<Field<?>> FIELDS = Collections.setOf(ALG, KID, KEY_OPS, KTY); static final Set<Field<?>> FIELDS = Collections.setOf(ALG, KID, KEY_OPS, KTY);
public static final String IMMUTABLE_MSG = "JWKs are immutable and may not be modified."; public static final String IMMUTABLE_MSG = "JWKs are immutable and may not be modified.";
protected final JwkContext<K> context; protected final JwkContext<K> context;
private final List<Field<?>> THUMBPRINT_FIELDS; private final List<Field<?>> THUMBPRINT_FIELDS;
private final int hashCode;
/** /**
* @param ctx the backing JwkContext containing the JWK field values. * @param ctx the backing JwkContext containing the JWK field values.
@ -71,6 +75,40 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
String kid = thumbprint.toString(); String kid = thumbprint.toString();
ctx.setId(kid); ctx.setId(kid);
} }
this.hashCode = computeHashCode();
}
/**
* Compute and return the JWK hashCode. As JWKs are immutable, this value will be cached as a final constant
* upon JWK instantiation. This uses the JWK's thumbprint fields during computation, but differs from JwkThumbprint
* calculation in two ways:
* <ol>
* <li>JwkThumbprints use a MessageDigest calculation, which is unnecessary overhead for a hashcode</li>
* <li>The hashCode calculation uses each field's idiomatic (Java) object value instead of the
* JwkThumbprint-required canonical (String) value.</li>
* </ol>
*
* @return the JWK hashcode
*/
private int computeHashCode() {
List<Object> list = new ArrayList<>(this.THUMBPRINT_FIELDS.size() + 1 /* possible discriminator */);
// So we don't leak information about the private key value, we need a discriminator to ensure that
// public and private key hashCodes are not identical (in case both JWKs need to be in the same hash set).
// So we add a discriminator String to the list of values that are used during hashCode calculation
Key key = Assert.notNull(toKey(), "JWK toKey() value cannot be null.");
if (key instanceof PublicKey) {
list.add("Public");
} else if (key instanceof PrivateKey) {
list.add("Private");
}
for (Field<?> field : this.THUMBPRINT_FIELDS) {
// Unlike thumbprint calculation, we get the idiomatic (Java) value, not canonical (String) value
// (We could have used either actually, but the idiomatic value hashCode calculation is probably
// faster).
Object val = Assert.notNull(get(field), "computeHashCode: Field idiomatic value cannot be null.");
list.add(val);
}
return Objects.nullSafeHashCode(list.toArray());
} }
private String getRequiredThumbprintValue(Field<?> field) { private String getRequiredThumbprintValue(Field<?> field) {
@ -230,13 +268,20 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
} }
@Override @Override
public int hashCode() { public final int hashCode() {
return this.context.hashCode(); return this.hashCode;
} }
@SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
@Override @Override
public boolean equals(Object obj) { public final boolean equals(Object obj) {
return this.context.equals(obj); if (obj == this) return true;
if (obj instanceof Jwk<?>) {
Jwk<?> other = (Jwk<?>) obj;
// this.getType() guaranteed non-null in constructor:
return getType().equals(other.getType()) && equals(other);
}
return false;
} }
protected abstract boolean equals(Jwk<?> jwk);
} }

View File

@ -170,6 +170,12 @@ abstract class AbstractJwkBuilder<K extends Key, J extends Jwk<K>, T extends Jwk
implements SecretJwkBuilder { implements SecretJwkBuilder {
public DefaultSecretJwkBuilder(JwkContext<SecretKey> ctx) { public DefaultSecretJwkBuilder(JwkContext<SecretKey> ctx) {
super(ctx); super(ctx);
// assign a standard algorithm if possible:
Key key = Assert.notNull(ctx.getKey(), "SecretKey cannot be null.");
DefaultMacAlgorithm mac = DefaultMacAlgorithm.findByKey(key);
if (mac != null) {
algorithm(mac.getId());
}
} }
} }
} }

View File

@ -17,6 +17,7 @@ package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.security.Jwk;
import io.jsonwebtoken.security.KeyPair; import io.jsonwebtoken.security.KeyPair;
import io.jsonwebtoken.security.PrivateJwk; import io.jsonwebtoken.security.PrivateJwk;
import io.jsonwebtoken.security.PublicJwk; import io.jsonwebtoken.security.PublicJwk;
@ -47,4 +48,11 @@ abstract class AbstractPrivateJwk<K extends PrivateKey, L extends PublicKey, M e
public KeyPair<L, K> toKeyPair() { public KeyPair<L, K> toKeyPair() {
return this.keyPair; return this.keyPair;
} }
@Override
protected final boolean equals(Jwk<?> jwk) {
return jwk instanceof PrivateJwk && equals((PrivateJwk<?, ?, ?>) jwk);
}
protected abstract boolean equals(PrivateJwk<?, ?, ?> jwk);
} }

View File

@ -16,6 +16,7 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.security.Jwk;
import io.jsonwebtoken.security.PublicJwk; import io.jsonwebtoken.security.PublicJwk;
import java.security.PublicKey; import java.security.PublicKey;
@ -25,4 +26,11 @@ abstract class AbstractPublicJwk<K extends PublicKey> extends AbstractAsymmetric
AbstractPublicJwk(JwkContext<K> ctx, List<Field<?>> thumbprintFields) { AbstractPublicJwk(JwkContext<K> ctx, List<Field<?>> thumbprintFields) {
super(ctx, thumbprintFields); super(ctx, thumbprintFields);
} }
@Override
protected final boolean equals(Jwk<?> jwk) {
return jwk instanceof PublicJwk && equals((PublicJwk<?>) jwk);
}
protected abstract boolean equals(PublicJwk<?> jwk);
} }

View File

@ -20,12 +20,15 @@ import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.EcPrivateJwk; import io.jsonwebtoken.security.EcPrivateJwk;
import io.jsonwebtoken.security.EcPublicJwk; import io.jsonwebtoken.security.EcPublicJwk;
import io.jsonwebtoken.security.PrivateJwk;
import java.math.BigInteger; import java.math.BigInteger;
import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey; import java.security.interfaces.ECPublicKey;
import java.util.Set; import java.util.Set;
import static io.jsonwebtoken.impl.security.DefaultEcPublicJwk.equalsPublic;
class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey, EcPublicJwk> implements EcPrivateJwk { class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey, EcPublicJwk> implements EcPrivateJwk {
static final Field<BigInteger> D = Fields.secretBigInt("d", "ECC Private Key"); static final Field<BigInteger> D = Fields.secretBigInt("d", "ECC Private Key");
@ -38,4 +41,9 @@ class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey,
DefaultEcPublicJwk.THUMBPRINT_FIELDS, DefaultEcPublicJwk.THUMBPRINT_FIELDS,
pubJwk); pubJwk);
} }
@Override
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
return jwk instanceof EcPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
}
} }

View File

@ -16,9 +16,11 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields; import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.EcPublicJwk; import io.jsonwebtoken.security.EcPublicJwk;
import io.jsonwebtoken.security.PublicJwk;
import java.math.BigInteger; import java.math.BigInteger;
import java.security.interfaces.ECPublicKey; import java.security.interfaces.ECPublicKey;
@ -39,4 +41,15 @@ class DefaultEcPublicJwk extends AbstractPublicJwk<ECPublicKey> implements EcPub
DefaultEcPublicJwk(JwkContext<ECPublicKey> ctx) { DefaultEcPublicJwk(JwkContext<ECPublicKey> ctx) {
super(ctx, THUMBPRINT_FIELDS); super(ctx, THUMBPRINT_FIELDS);
} }
static boolean equalsPublic(FieldReadable self, Object candidate) {
return Fields.equals(self, candidate, CRV) &&
Fields.equals(self, candidate, X) &&
Fields.equals(self, candidate, Y);
}
@Override
protected boolean equals(PublicJwk<?> jwk) {
return jwk instanceof EcPublicJwk && equalsPublic(this, jwk);
}
} }

View File

@ -53,7 +53,7 @@ final class DefaultMacAlgorithm extends AbstractSecureDigestAlgorithm<SecretKey,
static final DefaultMacAlgorithm HS384 = new DefaultMacAlgorithm(384); static final DefaultMacAlgorithm HS384 = new DefaultMacAlgorithm(384);
static final DefaultMacAlgorithm HS512 = new DefaultMacAlgorithm(512); static final DefaultMacAlgorithm HS512 = new DefaultMacAlgorithm(512);
private static final Map<String, MacAlgorithm> JCA_NAME_MAP; private static final Map<String, DefaultMacAlgorithm> JCA_NAME_MAP;
static { static {
JCA_NAME_MAP = new LinkedHashMap<>(6); JCA_NAME_MAP = new LinkedHashMap<>(6);
@ -96,7 +96,7 @@ final class DefaultMacAlgorithm extends AbstractSecureDigestAlgorithm<SecretKey,
return JCA_NAME_MAP.containsKey(key); return JCA_NAME_MAP.containsKey(key);
} }
static MacAlgorithm findByKey(Key key) { static DefaultMacAlgorithm findByKey(Key key) {
String alg = KeysBridge.findAlgorithm(key); String alg = KeysBridge.findAlgorithm(key);
if (!Strings.hasText(alg)) { if (!Strings.hasText(alg)) {
@ -104,7 +104,7 @@ final class DefaultMacAlgorithm extends AbstractSecureDigestAlgorithm<SecretKey,
} }
String upper = alg.toUpperCase(Locale.ENGLISH); String upper = alg.toUpperCase(Locale.ENGLISH);
MacAlgorithm mac = JCA_NAME_MAP.get(upper); DefaultMacAlgorithm mac = JCA_NAME_MAP.get(upper);
if (mac == null) { if (mac == null) {
return null; return null;
} }

View File

@ -20,12 +20,16 @@ import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.OctetPrivateJwk; import io.jsonwebtoken.security.OctetPrivateJwk;
import io.jsonwebtoken.security.OctetPublicJwk; import io.jsonwebtoken.security.OctetPublicJwk;
import io.jsonwebtoken.security.PrivateJwk;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.PublicKey; import java.security.PublicKey;
import java.util.Set; import java.util.Set;
public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> { import static io.jsonwebtoken.impl.security.DefaultOctetPublicJwk.equalsPublic;
public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey>
extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> {
static final Field<byte[]> D = Fields.bytes("d", "The private key").setSecret(true).build(); static final Field<byte[]> D = Fields.bytes("d", "The private key").setSecret(true).build();
@ -37,4 +41,9 @@ public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> e
// https://www.rfc-editor.org/rfc/rfc7638#section-3.2.1 // https://www.rfc-editor.org/rfc/rfc7638#section-3.2.1
DefaultOctetPublicJwk.THUMBPRINT_FIELDS, pubJwk); DefaultOctetPublicJwk.THUMBPRINT_FIELDS, pubJwk);
} }
@Override
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
return jwk instanceof OctetPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
}
} }

View File

@ -16,9 +16,11 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields; import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.OctetPublicJwk; import io.jsonwebtoken.security.OctetPublicJwk;
import io.jsonwebtoken.security.PublicJwk;
import java.security.PublicKey; import java.security.PublicKey;
import java.util.List; import java.util.List;
@ -37,4 +39,13 @@ public class DefaultOctetPublicJwk<T extends PublicKey> extends AbstractPublicJw
DefaultOctetPublicJwk(JwkContext<T> ctx) { DefaultOctetPublicJwk(JwkContext<T> ctx) {
super(ctx, THUMBPRINT_FIELDS); super(ctx, THUMBPRINT_FIELDS);
} }
static boolean equalsPublic(FieldReadable self, Object candidate) {
return Fields.equals(self, candidate, CRV) && Fields.equals(self, candidate, X);
}
@Override
protected boolean equals(PublicJwk<?> jwk) {
return jwk instanceof OctetPublicJwk && equalsPublic(this, jwk);
}
} }

View File

@ -16,8 +16,10 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields; import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.PrivateJwk;
import io.jsonwebtoken.security.RsaPrivateJwk; import io.jsonwebtoken.security.RsaPrivateJwk;
import io.jsonwebtoken.security.RsaPublicJwk; import io.jsonwebtoken.security.RsaPublicJwk;
@ -28,6 +30,8 @@ import java.security.spec.RSAOtherPrimeInfo;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import static io.jsonwebtoken.impl.security.DefaultRsaPublicJwk.equalsPublic;
class DefaultRsaPrivateJwk extends AbstractPrivateJwk<RSAPrivateKey, RSAPublicKey, RsaPublicJwk> implements RsaPrivateJwk { class DefaultRsaPrivateJwk extends AbstractPrivateJwk<RSAPrivateKey, RSAPublicKey, RsaPublicJwk> implements RsaPrivateJwk {
static final Field<BigInteger> PRIVATE_EXPONENT = Fields.secretBigInt("d", "Private Exponent"); static final Field<BigInteger> PRIVATE_EXPONENT = Fields.secretBigInt("d", "Private Exponent");
@ -54,4 +58,39 @@ class DefaultRsaPrivateJwk extends AbstractPrivateJwk<RSAPrivateKey, RSAPublicKe
DefaultRsaPublicJwk.THUMBPRINT_FIELDS, DefaultRsaPublicJwk.THUMBPRINT_FIELDS,
pubJwk); pubJwk);
} }
private static boolean equals(RSAOtherPrimeInfo a, RSAOtherPrimeInfo b) {
if (a == b) return true;
if (a == null || b == null) return false;
return Fields.bytesEquals(a.getPrime(), b.getPrime()) &&
Fields.bytesEquals(a.getExponent(), b.getExponent()) &&
Fields.bytesEquals(a.getCrtCoefficient(), b.getCrtCoefficient());
}
private static boolean equalsOtherPrimes(FieldReadable a, FieldReadable b) {
List<RSAOtherPrimeInfo> aOthers = a.get(OTHER_PRIMES_INFO);
List<RSAOtherPrimeInfo> bOthers = b.get(OTHER_PRIMES_INFO);
int aSize = Collections.size(aOthers);
int bSize = Collections.size(bOthers);
if (aSize != bSize) return false;
if (aSize == 0) return true;
RSAOtherPrimeInfo[] aInfos = aOthers.toArray(new RSAOtherPrimeInfo[0]);
RSAOtherPrimeInfo[] bInfos = bOthers.toArray(new RSAOtherPrimeInfo[0]);
for (int i = 0; i < aSize; i++) {
if (!equals(aInfos[i], bInfos[i])) return false;
}
return true;
}
@Override
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
return jwk instanceof RsaPrivateJwk && equalsPublic(this, jwk) &&
Fields.equals(this, jwk, PRIVATE_EXPONENT) &&
Fields.equals(this, jwk, FIRST_PRIME) &&
Fields.equals(this, jwk, SECOND_PRIME) &&
Fields.equals(this, jwk, FIRST_CRT_EXPONENT) &&
Fields.equals(this, jwk, SECOND_CRT_EXPONENT) &&
Fields.equals(this, jwk, FIRST_CRT_COEFFICIENT) &&
equalsOtherPrimes(this, (FieldReadable) jwk);
}
} }

View File

@ -16,8 +16,10 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields; import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.PublicJwk;
import io.jsonwebtoken.security.RsaPublicJwk; import io.jsonwebtoken.security.RsaPublicJwk;
import java.math.BigInteger; import java.math.BigInteger;
@ -38,4 +40,13 @@ class DefaultRsaPublicJwk extends AbstractPublicJwk<RSAPublicKey> implements Rsa
DefaultRsaPublicJwk(JwkContext<RSAPublicKey> ctx) { DefaultRsaPublicJwk(JwkContext<RSAPublicKey> ctx) {
super(ctx, THUMBPRINT_FIELDS); super(ctx, THUMBPRINT_FIELDS);
} }
static boolean equalsPublic(FieldReadable self, Object candidate) {
return Fields.equals(self, candidate, MODULUS) && Fields.equals(self, candidate, PUBLIC_EXPONENT);
}
@Override
protected boolean equals(PublicJwk<?> jwk) {
return jwk instanceof RsaPublicJwk && equalsPublic(this, jwk);
}
} }

View File

@ -18,6 +18,7 @@ package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Field; import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.Fields; import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.Jwk;
import io.jsonwebtoken.security.SecretJwk; import io.jsonwebtoken.security.SecretJwk;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
@ -36,4 +37,9 @@ class DefaultSecretJwk extends AbstractJwk<SecretKey> implements SecretJwk {
DefaultSecretJwk(JwkContext<SecretKey> ctx) { DefaultSecretJwk(JwkContext<SecretKey> ctx) {
super(ctx, THUMBPRINT_FIELDS); super(ctx, THUMBPRINT_FIELDS);
} }
@Override
protected boolean equals(Jwk<?> jwk) {
return jwk instanceof SecretJwk && Fields.equals(this, jwk, K);
}
} }

View File

@ -144,7 +144,7 @@ class AbstractProtectedHeaderTest {
h([jwk: jwk]) h([jwk: jwk])
fail() fail()
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
String msg = "Invalid JWT header 'jwk' (JSON Web Key) value: {kty=oct, k=<redacted>}. " + String msg = "Invalid JWT header 'jwk' (JSON Web Key) value: {alg=HS256, kty=oct, k=<redacted>}. " +
"Value must be a Public JWK, not a Secret JWK." "Value must be a Public JWK, not a Secret JWK."
assertEquals msg, expected.getMessage() assertEquals msg, expected.getMessage()
} }

View File

@ -63,7 +63,7 @@ class DefaultJweHeaderTest {
h([epk: values]) h([epk: values])
fail() fail()
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
String msg = "Invalid JWE header 'epk' (Ephemeral Public Key) value: {kty=oct, k=<redacted>}. " + String msg = "Invalid JWE header 'epk' (Ephemeral Public Key) value: {alg=HS256, kty=oct, k=<redacted>}. " +
"Value must be a Public JWK, not a Secret JWK." "Value must be a Public JWK, not a Secret JWK."
assertEquals msg, expected.getMessage() assertEquals msg, expected.getMessage()
} }

View File

@ -290,4 +290,39 @@ class BytesTest {
Bytes.length(-1) Bytes.length(-1)
} }
@Test
void testClearNull() {
Bytes.clear(null) // no exception
}
@Test
void testClearEmpty() {
Bytes.clear(Bytes.EMPTY) // no exception
}
@Test
void testClear() {
int len = 16
byte[] bytes = Bytes.random(len)
boolean allZero = true
for(int i = 0; i < len; i++) {
if (bytes[i] != (byte)0) {
allZero = false
break
}
}
assertFalse allZero // guarantee that we start with random bytes
Bytes.clear(bytes)
allZero = true
for(int i = 0; i < len; i++) {
if (bytes[i] != (byte)0) {
allZero = false
break
}
}
assertTrue allZero // asserts zeroed out entirely
}
} }

View File

@ -221,4 +221,70 @@ class FieldsTest {
def field = Fields.builder(String.class).setId('foo').setName("FooName").build() def field = Fields.builder(String.class).setId('foo').setName("FooName").build()
assertFalse field.equals(new Object()) assertFalse field.equals(new Object())
} }
@Test
void testBigIntegerBytesNull() {
assertNull Fields.bytes(null)
}
@Test
void testBytesEqualsWhenBothAreNull() {
assertTrue Fields.bytesEquals(null, null)
}
@Test
void testBytesEqualsIdentity() {
assertTrue Fields.bytesEquals(BigInteger.ONE, BigInteger.ONE)
}
@Test
void testBytesEqualsWhenAIsNull() {
assertFalse Fields.bytesEquals(null, BigInteger.ONE)
}
@Test
void testBytesEqualsWhenBIsNull() {
assertFalse Fields.bytesEquals(BigInteger.ONE, null)
}
@Test
void testFieldValueEqualsWhenAIsNull() {
BigInteger a = null
BigInteger b = BigInteger.ONE
Field<BigInteger> field = Fields.bigInt('foo', 'bar').build()
assertFalse Fields.equals(a, b, field)
}
@Test
void testFieldValueEqualsWhenBIsNull() {
BigInteger a = BigInteger.ONE
BigInteger b = null
Field<BigInteger> field = Fields.bigInt('foo', 'bar').build()
assertFalse Fields.equals(a, b, field)
}
@Test
void testFieldValueEqualsSecretString() {
String a = 'hello'
String b = new String('hello'.toCharArray()) // new instance not in the string table (Groovy side effect)
Field<String> field = Fields.builder(String.class).setId('foo').setName('bar').setSecret(true).build()
assertTrue Fields.equals(a, b, field)
}
@Test
void testEqualsIdentity() {
FieldReadable r = new TestFieldReadable()
assertTrue Fields.equals(r, r, Fields.string('foo', 'bar'))
}
@Test
void testEqualsWhenAIsNull() {
assertFalse Fields.equals(null, "hello", Fields.string('foo', 'bar'))
}
@Test
void testEqualsWhenAIsFieldReadableButBIsNot() {
FieldReadable r = new TestFieldReadable()
assertFalse Fields.equals(r, "hello", Fields.string('foo', 'bar'))
}
} }

View File

@ -17,8 +17,7 @@ package io.jsonwebtoken.impl.lang
import org.junit.Test import org.junit.Test
import static org.junit.Assert.assertFalse import static org.junit.Assert.*
import static org.junit.Assert.assertTrue
class RedactedSupplierTest { class RedactedSupplierTest {
@ -43,4 +42,16 @@ class RedactedSupplierTest {
assertFalse new RedactedSupplier<>(42).equals(new RedactedSupplier(30)) assertFalse new RedactedSupplier<>(42).equals(new RedactedSupplier(30))
} }
@Test
void testEqualsIdentity() {
def supplier = new RedactedSupplier('hello')
assertEquals supplier, supplier
}
@Test
void testHashCode() {
int hashCode = 42.hashCode()
assertEquals hashCode, new RedactedSupplier(42).hashCode()
}
} }

View File

@ -0,0 +1,26 @@
/*
* Copyright © 2023 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.lang
class TestFieldReadable implements FieldReadable {
def value = null
@Override
Object get(Field field) {
return value
}
}

View File

@ -16,7 +16,9 @@
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.lang.Collections import io.jsonwebtoken.lang.Collections
import io.jsonwebtoken.security.Jwk
import io.jsonwebtoken.security.Jwks import io.jsonwebtoken.security.Jwks
import io.jsonwebtoken.security.SecretJwk
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -44,7 +46,12 @@ class AbstractJwkTest {
} }
static AbstractJwk<SecretKey> newJwk(JwkContext<SecretKey> ctx) { static AbstractJwk<SecretKey> newJwk(JwkContext<SecretKey> ctx) {
return new AbstractJwk(ctx, Collections.of(AbstractJwk.KTY)) {} return new AbstractJwk(ctx, Collections.of(AbstractJwk.KTY)) {
@Override
protected boolean equals(Jwk jwk) {
return this.@context.equals(jwk.@context)
}
}
} }
@Before @Before
@ -144,24 +151,22 @@ class AbstractJwkTest {
@Test @Test
void testPrivateJwkHashCode() { void testPrivateJwkHashCode() {
assertEquals jwk.hashCode(), jwk.@context.hashCode()
def secretJwk1 = Jwks.builder().key(TestKeys.HS256).add('hello', 'world').build() def secretJwk1 = Jwks.builder().key(TestKeys.HS256).add('hello', 'world').build()
def secretJwk2 = Jwks.builder().key(TestKeys.HS256).add('hello', 'world').build() def secretJwk2 = Jwks.builder().key(TestKeys.HS256).add('hello', 'world').build()
assertEquals secretJwk1.hashCode(), secretJwk1.@context.hashCode()
assertEquals secretJwk2.hashCode(), secretJwk2.@context.hashCode()
assertEquals secretJwk1.hashCode(), secretJwk2.hashCode() assertEquals secretJwk1.hashCode(), secretJwk2.hashCode()
def ecPrivJwk1 = Jwks.builder().key(TestKeys.ES256.pair.private).add('hello', 'ecworld').build() def ecPrivJwk1 = Jwks.builder().key(TestKeys.ES256.pair.private).add('hello', 'ecworld').build()
def ecPrivJwk2 = Jwks.builder().key(TestKeys.ES256.pair.private).add('hello', 'ecworld').build() def ecPrivJwk2 = Jwks.builder().key(TestKeys.ES256.pair.private).add('hello', 'ecworld').build()
assertEquals ecPrivJwk1.hashCode(), ecPrivJwk2.hashCode() assertEquals ecPrivJwk1.hashCode(), ecPrivJwk2.hashCode()
assertEquals ecPrivJwk1.hashCode(), ecPrivJwk1.@context.hashCode()
assertEquals ecPrivJwk2.hashCode(), ecPrivJwk2.@context.hashCode()
def rsaPrivJwk1 = Jwks.builder().key(TestKeys.RS256.pair.private).add('hello', 'rsaworld').build() def rsaPrivJwk1 = Jwks.builder().key(TestKeys.RS256.pair.private).add('hello', 'rsaworld').build()
def rsaPrivJwk2 = Jwks.builder().key(TestKeys.RS256.pair.private).add('hello', 'rsaworld').build() def rsaPrivJwk2 = Jwks.builder().key(TestKeys.RS256.pair.private).add('hello', 'rsaworld').build()
assertEquals rsaPrivJwk1.hashCode(), rsaPrivJwk2.hashCode() assertEquals rsaPrivJwk1.hashCode(), rsaPrivJwk2.hashCode()
assertEquals rsaPrivJwk1.hashCode(), rsaPrivJwk1.@context.hashCode() }
assertEquals rsaPrivJwk2.hashCode(), rsaPrivJwk2.@context.hashCode()
@Test
void testEqualsWithNonJwk() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build()
assertFalse jwk.equals(42)
} }
} }

View File

@ -0,0 +1,73 @@
/*
* Copyright © 2023 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.FieldReadable
import io.jsonwebtoken.impl.lang.TestFieldReadable
import org.junit.Test
import java.security.spec.RSAOtherPrimeInfo
import static org.junit.Assert.assertFalse
import static org.junit.Assert.assertTrue
class DefaultRsaPrivateJwkTest {
@Test
void testEqualsOtherPrimesDifferentSizes() {
def info1 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
def info2 = new RSAOtherPrimeInfo(BigInteger.TEN, BigInteger.TEN, BigInteger.TEN)
FieldReadable a = new TestFieldReadable(value: [info1, info2])
FieldReadable b = new TestFieldReadable(value: [info1]) // different sizes
assertFalse DefaultRsaPrivateJwk.equalsOtherPrimes(a, b)
}
@Test
void testEqualsOtherPrimes() {
def info1 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
def info2 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
FieldReadable a = new TestFieldReadable(value: [info1])
FieldReadable b = new TestFieldReadable(value: [info2])
assertTrue DefaultRsaPrivateJwk.equalsOtherPrimes(a, b)
}
@Test
void testEqualsOtherPrimesIdentity() {
def info1 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
FieldReadable a = new TestFieldReadable(value: [info1])
FieldReadable b = new TestFieldReadable(value: [info1])
assertTrue DefaultRsaPrivateJwk.equalsOtherPrimes(a, b)
}
@Test
void testEqualsOtherPrimesNullElement() {
def info1 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
// sizes are the same, but one element is null:
FieldReadable a = new TestFieldReadable(value: [null])
FieldReadable b = new TestFieldReadable(value: [info1])
assertFalse DefaultRsaPrivateJwk.equalsOtherPrimes(a, b)
}
@Test
void testEqualsOtherPrimesInfoNotEqual() {
def info1 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.ONE)
def info2 = new RSAOtherPrimeInfo(BigInteger.ONE, BigInteger.ONE, BigInteger.TEN) // different
FieldReadable a = new TestFieldReadable(value: [info1])
FieldReadable b = new TestFieldReadable(value: [info2])
assertFalse DefaultRsaPrivateJwk.equalsOtherPrimes(a, b)
}
}

View File

@ -55,14 +55,14 @@ class JwksTest {
//test non-null value: //test non-null value:
//noinspection GroovyAssignabilityCheck //noinspection GroovyAssignabilityCheck
def builder = Jwks.builder().key(key) def builder = Jwks.builder().key(key).delete('alg') // delete alg put there by SecretKeyBuilder
builder."$name"(val) builder."$name"(val)
def jwk = builder.build() def jwk = builder.build()
assertEquals val, jwk."get${cap}"() assertEquals val, jwk."get${cap}"()
assertEquals expectedFieldValue, jwk."${id}" assertEquals expectedFieldValue, jwk."${id}"
//test null value: //test null value:
builder = Jwks.builder().key(key) builder = Jwks.builder().key(key).delete('alg')
try { try {
builder."$name"(null) builder."$name"(null)
fail("IAE should have been thrown") fail("IAE should have been thrown")
@ -74,7 +74,7 @@ class JwksTest {
assertFalse jwk.containsKey(id) assertFalse jwk.containsKey(id)
//test empty string value //test empty string value
builder = Jwks.builder().key(key) builder = Jwks.builder().key(key).delete('alg')
if (val instanceof String) { if (val instanceof String) {
try { try {
builder."$name"(' ' as String) builder."$name"(' ' as String)

View File

@ -33,7 +33,7 @@ class SecretJwkFactoryTest {
@Test @Test
// if a jwk does not have an 'alg' or 'use' field, we default to an AES key // if a jwk does not have an 'alg' or 'use' field, we default to an AES key
void testNoAlgNoSigJcaName() { void testNoAlgNoSigJcaName() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).delete('alg').build()
SecretJwk result = Jwks.builder().add(jwk).build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).build() as SecretJwk
assertEquals 'AES', result.toKey().getAlgorithm() assertEquals 'AES', result.toKey().getAlgorithm()
} }
@ -41,13 +41,13 @@ class SecretJwkFactoryTest {
@Test @Test
void testJwkHS256AlgSetsKeyJcaNameCorrectly() { void testJwkHS256AlgSetsKeyJcaNameCorrectly() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build()
SecretJwk result = Jwks.builder().add(jwk).add('alg', 'HS256').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).build() as SecretJwk
assertEquals 'HmacSHA256', result.toKey().getAlgorithm() assertEquals 'HmacSHA256', result.toKey().getAlgorithm()
} }
@Test @Test
void testSignOpSetsKeyHmacSHA256() { void testSignOpSetsKeyHmacSHA256() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).delete('alg').build()
SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk
assertNull result.getAlgorithm() assertNull result.getAlgorithm()
assertNull result.get('use') assertNull result.get('use')
@ -57,13 +57,13 @@ class SecretJwkFactoryTest {
@Test @Test
void testJwkHS384AlgSetsKeyJcaNameCorrectly() { void testJwkHS384AlgSetsKeyJcaNameCorrectly() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).build()
SecretJwk result = Jwks.builder().add(jwk).add('alg', 'HS384').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).build() as SecretJwk
assertEquals 'HmacSHA384', result.toKey().getAlgorithm() assertEquals 'HmacSHA384', result.toKey().getAlgorithm()
} }
@Test @Test
void testSignOpSetsKeyHmacSHA384() { void testSignOpSetsKeyHmacSHA384() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).delete('alg').build()
SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk
assertNull result.getAlgorithm() assertNull result.getAlgorithm()
assertNull result.get('use') assertNull result.get('use')
@ -73,13 +73,13 @@ class SecretJwkFactoryTest {
@Test @Test
void testJwkHS512AlgSetsKeyJcaNameCorrectly() { void testJwkHS512AlgSetsKeyJcaNameCorrectly() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).build()
SecretJwk result = Jwks.builder().add(jwk).add('alg', 'HS512').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).build() as SecretJwk
assertEquals 'HmacSHA512', result.toKey().getAlgorithm() assertEquals 'HmacSHA512', result.toKey().getAlgorithm()
} }
@Test @Test
void testSignOpSetsKeyHmacSHA512() { void testSignOpSetsKeyHmacSHA512() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).delete('alg').build()
SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).operations([Jwks.OP.SIGN]).build() as SecretJwk
assertNull result.getAlgorithm() assertNull result.getAlgorithm()
assertNull result.get('use') assertNull result.get('use')
@ -89,7 +89,7 @@ class SecretJwkFactoryTest {
@Test @Test
// no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA256 // no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA256
void testNoAlgAndSigUseForHS256() { void testNoAlgAndSigUseForHS256() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).delete('alg').build()
assertFalse jwk.containsKey('alg') assertFalse jwk.containsKey('alg')
assertFalse jwk.containsKey('use') assertFalse jwk.containsKey('use')
SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk
@ -99,7 +99,7 @@ class SecretJwkFactoryTest {
@Test @Test
// no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA384 // no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA384
void testNoAlgAndSigUseForHS384() { void testNoAlgAndSigUseForHS384() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS384).delete('alg').build()
assertFalse jwk.containsKey('alg') assertFalse jwk.containsKey('alg')
assertFalse jwk.containsKey('use') assertFalse jwk.containsKey('use')
SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk
@ -109,7 +109,7 @@ class SecretJwkFactoryTest {
@Test @Test
// no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA512 // no 'alg' jwk property, but 'use' is 'sig', so forces jcaName to be HmacSHA512
void testNoAlgAndSigUseForHS512() { void testNoAlgAndSigUseForHS512() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS512).delete('alg').build()
assertFalse jwk.containsKey('alg') assertFalse jwk.containsKey('alg')
assertFalse jwk.containsKey('use') assertFalse jwk.containsKey('use')
SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).add('use', 'sig').build() as SecretJwk
@ -119,7 +119,7 @@ class SecretJwkFactoryTest {
@Test @Test
// no 'alg' jwk property, but 'use' is something other than 'sig', so jcaName should default to AES // no 'alg' jwk property, but 'use' is something other than 'sig', so jcaName should default to AES
void testNoAlgAndNonSigUse() { void testNoAlgAndNonSigUse() {
SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).build() SecretJwk jwk = Jwks.builder().key(TestKeys.HS256).delete('alg').build()
assertFalse jwk.containsKey('alg') assertFalse jwk.containsKey('alg')
assertFalse jwk.containsKey('use') assertFalse jwk.containsKey('use')
SecretJwk result = Jwks.builder().add(jwk).add('use', 'foo').build() as SecretJwk SecretJwk result = Jwks.builder().add(jwk).add('use', 'foo').build() as SecretJwk