Consolidate BouncyCastle lookup/fallback logic to JcaTemplate (#798)

* Consolidating BouncyCastle lookup/fallback behavior to JcaTemplate to avoid complexity in Algorithm implementations

* Added JcaTemplate generateX509Certificate helper method to enable BC-fallback behavior if necessary

* Further reduced code dependencies on Providers class.  Now only used by JcaTemplate, JcaTemplateTest and ProvidersTest

* Removed Condition and Conditions concepts - no longer needed now that Providers.java no longer requires conditional loading
This commit is contained in:
lhazlewood 2023-08-24 11:56:23 -07:00 committed by GitHub
parent ed98f3d706
commit eca568ec16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 738 additions and 1065 deletions

View File

@ -1,23 +0,0 @@
/*
* Copyright © 2021 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.lang;
/**
* @since JJWT_RELEASE_VERSION
*/
public interface Condition {
boolean test();
}

View File

@ -1,90 +0,0 @@
/*
* Copyright © 2021 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.lang;
import io.jsonwebtoken.lang.Assert;
/**
* @since JJWT_RELEASE_VERSION
*/
public final class Conditions {
private Conditions() {
}
public static final Condition TRUE = of(true);
public static Condition of(boolean val) {
return new BooleanCondition(val);
}
public static Condition not(Condition c) {
return new NotCondition(c);
}
public static Condition exists(CheckedSupplier<?> s) {
return new ExistsCondition(s);
}
public static Condition notExists(CheckedSupplier<?> s) {
return not(exists(s));
}
private static final class NotCondition implements Condition {
private final Condition c;
private NotCondition(Condition c) {
this.c = Assert.notNull(c, "Condition cannot be null.");
}
@Override
public boolean test() {
return !c.test();
}
}
private static final class BooleanCondition implements Condition {
private final boolean value;
public BooleanCondition(boolean value) {
this.value = value;
}
@Override
public boolean test() {
return value;
}
}
private static final class ExistsCondition implements Condition {
private final CheckedSupplier<?> supplier;
ExistsCondition(CheckedSupplier<?> supplier) {
this.supplier = Assert.notNull(supplier, "CheckedSupplier cannot be null.");
}
@Override
public boolean test() {
Object value = null;
try {
value = supplier.get();
} catch (Throwable ignored) {
}
return value != null;
}
}
}

View File

@ -24,19 +24,12 @@ package io.jsonwebtoken.impl.lang;
*/ */
public final class ConstantFunction<T, R> implements Function<T, R> { public final class ConstantFunction<T, R> implements Function<T, R> {
private static final Function<?, ?> NULL = new ConstantFunction<>(null);
private final R value; private final R value;
public ConstantFunction(R value) { public ConstantFunction(R value) {
this.value = value; this.value = value;
} }
@SuppressWarnings("unchecked")
public static <T, R> Function<T, R> forNull() {
return (Function<T, R>) NULL;
}
@Override @Override
public R apply(T t) { public R apply(T t) {
return this.value; return this.value;

View File

@ -22,10 +22,6 @@ public final class Functions {
private Functions() { private Functions() {
} }
public static <T, R> Function<T, R> forNull() {
return ConstantFunction.forNull();
}
public static <T> Function<T, T> identity() { public static <T> Function<T, T> identity() {
return new Function<T, T>() { return new Function<T, T>() {
@Override @Override

View File

@ -1,72 +0,0 @@
/*
* Copyright © 2023 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.lang;
import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Classes;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
public class OptionalCtorInvoker<T> extends ReflectionFunction<Object, T> {
private final Constructor<T> CTOR;
public OptionalCtorInvoker(String fqcn, Object... ctorArgTypesOrFqcns) {
Assert.hasText(fqcn, "fqcn cannot be null.");
Constructor<T> ctor = null;
try {
Class<T> clazz = Classes.forName(fqcn);
Class<?>[] ctorArgTypes = null;
if (Arrays.length(ctorArgTypesOrFqcns) > 0) {
ctorArgTypes = new Class<?>[ctorArgTypesOrFqcns.length];
List<Class<?>> l = new ArrayList<>(ctorArgTypesOrFqcns.length);
for (Object ctorArgTypeOrFqcn : ctorArgTypesOrFqcns) {
Class<?> ctorArgClass;
if (ctorArgTypeOrFqcn instanceof Class<?>) {
ctorArgClass = (Class<?>) ctorArgTypeOrFqcn;
} else {
String typeFqcn = Assert.isInstanceOf(String.class, ctorArgTypeOrFqcn, "ctorArgTypesOrFcqns array must contain Class or String instances.");
ctorArgClass = Classes.forName(typeFqcn);
}
l.add(ctorArgClass);
}
ctorArgTypes = l.toArray(ctorArgTypes);
}
ctor = Classes.getConstructor(clazz, ctorArgTypes);
} catch (Exception ignored) {
}
this.CTOR = ctor;
}
@Override
protected boolean supports(Object input) {
return CTOR != null;
}
@Override
protected T invoke(Object input) {
Object[] args = null;
if (input instanceof Object[]) {
args = (Object[]) input;
} else if (input != null) {
args = new Object[]{input};
}
return Classes.instantiate(CTOR, args);
}
}

View File

@ -16,8 +16,6 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.lang.Arrays; import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.security.AssociatedDataSupplier; import io.jsonwebtoken.security.AssociatedDataSupplier;
@ -28,7 +26,6 @@ import io.jsonwebtoken.security.Request;
import io.jsonwebtoken.security.SecretKeyBuilder; import io.jsonwebtoken.security.SecretKeyBuilder;
import io.jsonwebtoken.security.WeakKeyException; import io.jsonwebtoken.security.WeakKeyException;
import javax.crypto.Cipher;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.IvParameterSpec;
@ -62,17 +59,6 @@ abstract class AesAlgorithm extends CryptoAlgorithm implements KeyBuilderSupplie
this.ivBitLength = jcaTransformation.equals("AESWrap") ? 0 : (this.gcm ? GCM_IV_SIZE : BLOCK_SIZE); this.ivBitLength = jcaTransformation.equals("AESWrap") ? 0 : (this.gcm ? GCM_IV_SIZE : BLOCK_SIZE);
// https://tools.ietf.org/html/rfc7518#section-5.2.3 through https://tools.ietf.org/html/rfc7518#section-5.3 : // https://tools.ietf.org/html/rfc7518#section-5.2.3 through https://tools.ietf.org/html/rfc7518#section-5.3 :
this.tagBitLength = this.gcm ? BLOCK_SIZE : this.keyBitLength; this.tagBitLength = this.gcm ? BLOCK_SIZE : this.keyBitLength;
// GCM mode only available on JDK 8 and later, so enable BC as a backup provider if necessary for <= JDK 7:
// TODO: remove when dropping JDK 7:
if (this.gcm) {
setProvider(Providers.findBouncyCastle(Conditions.notExists(new CheckedSupplier<Cipher>() {
@Override
public Cipher get() throws Exception {
return Cipher.getInstance(jcaTransformation);
}
})));
}
} }
@Override @Override

View File

@ -35,8 +35,6 @@ abstract class CryptoAlgorithm implements Identifiable {
private final String jcaName; private final String jcaName;
private Provider provider; // default, if any
CryptoAlgorithm(String id, String jcaName) { CryptoAlgorithm(String id, String jcaName) {
Assert.hasText(id, "id cannot be null or empty."); Assert.hasText(id, "id cannot be null or empty.");
this.ID = id; this.ID = id;
@ -53,27 +51,19 @@ abstract class CryptoAlgorithm implements Identifiable {
return this.jcaName; return this.jcaName;
} }
protected void setProvider(Provider provider) { // can be null
this.provider = provider;
}
protected Provider getProvider() {
return this.provider;
}
SecureRandom ensureSecureRandom(Request<?> request) { SecureRandom ensureSecureRandom(Request<?> request) {
SecureRandom random = request != null ? request.getSecureRandom() : null; SecureRandom random = request != null ? request.getSecureRandom() : null;
return random != null ? random : Randoms.secureRandom(); return random != null ? random : Randoms.secureRandom();
} }
protected JcaTemplate jca() { protected JcaTemplate jca() {
return new JcaTemplate(getJcaName(), getProvider()); return new JcaTemplate(getJcaName(), null);
} }
protected JcaTemplate jca(Request<?> request) { protected JcaTemplate jca(Request<?> request) {
Assert.notNull(request, "request cannot be null."); Assert.notNull(request, "request cannot be null.");
String jcaName = Assert.hasText(getJcaName(request), "Request jcaName cannot be null or empty."); String jcaName = Assert.hasText(getJcaName(request), "Request jcaName cannot be null or empty.");
Provider provider = getProvider(request); Provider provider = request.getProvider();
SecureRandom random = ensureSecureRandom(request); SecureRandom random = ensureSecureRandom(request);
return new JcaTemplate(jcaName, provider, random); return new JcaTemplate(jcaName, provider, random);
} }
@ -82,18 +72,10 @@ abstract class CryptoAlgorithm implements Identifiable {
return getJcaName(); return getJcaName();
} }
protected Provider getProvider(Request<?> request) {
Provider provider = request.getProvider();
if (provider == null) {
provider = this.provider; // fallback, if any
}
return provider;
}
protected SecretKey generateKey(KeyRequest<?> request) { protected SecretKey generateKey(KeyRequest<?> request) {
AeadAlgorithm enc = Assert.notNull(request.getEncryptionAlgorithm(), "Request encryptionAlgorithm cannot be null."); AeadAlgorithm enc = Assert.notNull(request.getEncryptionAlgorithm(), "Request encryptionAlgorithm cannot be null.");
SecretKeyBuilder builder = Assert.notNull(enc.key(), "Request encryptionAlgorithm KeyBuilder cannot be null."); SecretKeyBuilder builder = Assert.notNull(enc.key(), "Request encryptionAlgorithm KeyBuilder cannot be null.");
SecretKey key = builder.provider(getProvider(request)).random(request.getSecureRandom()).build(); SecretKey key = builder.provider(request.getProvider()).random(request.getSecureRandom()).build();
return Assert.notNull(key, "Request encryptionAlgorithm SecretKeyBuilder cannot produce null keys."); return Assert.notNull(key, "Request encryptionAlgorithm SecretKeyBuilder cannot produce null keys.");
} }

View File

@ -20,24 +20,15 @@ import io.jsonwebtoken.lang.Strings;
import io.jsonwebtoken.security.Curve; import io.jsonwebtoken.security.Curve;
import io.jsonwebtoken.security.KeyPairBuilder; import io.jsonwebtoken.security.KeyPairBuilder;
import java.security.Provider;
class DefaultCurve implements Curve { class DefaultCurve implements Curve {
private final String ID; private final String ID;
private final String JCA_NAME; private final String JCA_NAME;
private final Provider PROVIDER; // can be null
DefaultCurve(String id, String jcaName) { DefaultCurve(String id, String jcaName) {
this(id, jcaName, null);
}
DefaultCurve(String id, String jcaName, Provider provider) {
this.ID = Assert.notNull(Strings.clean(id), "Curve ID cannot be null or empty."); this.ID = Assert.notNull(Strings.clean(id), "Curve ID cannot be null or empty.");
this.JCA_NAME = Assert.notNull(Strings.clean(jcaName), "Curve jcaName cannot be null or empty."); this.JCA_NAME = Assert.notNull(Strings.clean(jcaName), "Curve jcaName cannot be null or empty.");
this.PROVIDER = provider;
} }
@Override @Override
@ -49,10 +40,6 @@ class DefaultCurve implements Curve {
return this.JCA_NAME; return this.JCA_NAME;
} }
public Provider getProvider() {
return this.PROVIDER;
}
@Override @Override
public int hashCode() { public int hashCode() {
return ID.hashCode(); return ID.hashCode();
@ -76,6 +63,6 @@ class DefaultCurve implements Curve {
} }
public KeyPairBuilder keyPair() { public KeyPairBuilder keyPair() {
return new DefaultKeyPairBuilder(this.JCA_NAME).provider(this.PROVIDER); return new DefaultKeyPairBuilder(this.JCA_NAME);
} }
} }

View File

@ -22,7 +22,6 @@ import io.jsonwebtoken.security.Request;
import io.jsonwebtoken.security.VerifyDigestRequest; import io.jsonwebtoken.security.VerifyDigestRequest;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.Provider;
import java.util.Locale; import java.util.Locale;
public final class DefaultHashAlgorithm extends CryptoAlgorithm implements HashAlgorithm { public final class DefaultHashAlgorithm extends CryptoAlgorithm implements HashAlgorithm {
@ -33,11 +32,6 @@ public final class DefaultHashAlgorithm extends CryptoAlgorithm implements HashA
super(id, id.toUpperCase(Locale.ENGLISH)); super(id, id.toUpperCase(Locale.ENGLISH));
} }
DefaultHashAlgorithm(String id, String jcaName, Provider provider) {
super(id, jcaName);
setProvider(provider);
}
@Override @Override
public byte[] digest(final Request<byte[]> request) { public byte[] digest(final Request<byte[]> request) {
Assert.notNull(request, "Request cannot be null."); Assert.notNull(request, "Request cannot be null.");

View File

@ -155,7 +155,6 @@ final class EcSignatureAlgorithm extends AbstractSignatureAlgorithm {
@Override @Override
public KeyPairBuilder keyPair() { public KeyPairBuilder keyPair() {
return new DefaultKeyPairBuilder(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, this.KEY_PAIR_GEN_PARAMS) return new DefaultKeyPairBuilder(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, this.KEY_PAIR_GEN_PARAMS)
.provider(getProvider())
.random(Randoms.secureRandom()); .random(Randoms.secureRandom());
} }

View File

@ -69,8 +69,10 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
// where the Digest Method is SHA-256. // where the Digest Method is SHA-256.
private static final String CONCAT_KDF_HASH_ALG_NAME = "SHA-256"; private static final String CONCAT_KDF_HASH_ALG_NAME = "SHA-256";
private static final ConcatKDF CONCAT_KDF = new ConcatKDF(CONCAT_KDF_HASH_ALG_NAME); private static final ConcatKDF CONCAT_KDF = new ConcatKDF(CONCAT_KDF_HASH_ALG_NAME);
public static final String KEK_TYPE_MESSAGE = "Key Encryption Key must be a " + ECKey.class.getName() + " or valid Edwards Curve PublicKey instance."; public static final String KEK_TYPE_MESSAGE = "Key Encryption Key must be a " + ECKey.class.getName() +
public static final String KDK_TYPE_MESSAGE = "Key Decryption Key must be a " + ECKey.class.getName() + " or valid Edwards Curve PrivateKey instance."; " or valid Edwards Curve PublicKey instance.";
public static final String KDK_TYPE_MESSAGE = "Key Decryption Key must be a " + ECKey.class.getName() +
" or valid Edwards Curve PrivateKey instance.";
private final KeyAlgorithm<SecretKey, SecretKey> WRAP_ALG; private final KeyAlgorithm<SecretKey, SecretKey> WRAP_ALG;
@ -92,7 +94,8 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
//visible for testing, for non-Edwards elliptic curves //visible for testing, for non-Edwards elliptic curves
protected KeyPair generateKeyPair(final Request<?> request, final ECParameterSpec spec) { protected KeyPair generateKeyPair(final Request<?> request, final ECParameterSpec spec) {
Assert.notNull(spec, "request key params cannot be null."); Assert.notNull(spec, "request key params cannot be null.");
JcaTemplate template = new JcaTemplate(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, getProvider(request), ensureSecureRandom(request)); JcaTemplate template = new JcaTemplate(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, request.getProvider(),
ensureSecureRandom(request));
return template.generateKeyPair(spec); return template.generateKeyPair(spec);
} }
@ -113,7 +116,8 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
} }
protected String getConcatKDFAlgorithmId(AeadAlgorithm enc) { protected String getConcatKDFAlgorithmId(AeadAlgorithm enc) {
return this.WRAP_ALG instanceof DirectKeyAlgorithm ? Assert.hasText(enc.getId(), "AeadAlgorithm id cannot be null or empty.") : getId(); return this.WRAP_ALG instanceof DirectKeyAlgorithm ? Assert.hasText(enc.getId(),
"AeadAlgorithm id cannot be null or empty.") : getId();
} }
private byte[] createOtherInfo(int keydatalen, String AlgorithmID, byte[] PartyUInfo, byte[] PartyVInfo) { private byte[] createOtherInfo(int keydatalen, String AlgorithmID, byte[] PartyUInfo, byte[] PartyVInfo) {
@ -136,12 +140,14 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
} }
private int getKeyBitLength(AeadAlgorithm enc) { private int getKeyBitLength(AeadAlgorithm enc) {
int bitLength = this.WRAP_ALG instanceof KeyLengthSupplier ? ((KeyLengthSupplier) this.WRAP_ALG).getKeyBitLength() : enc.getKeyBitLength(); int bitLength = this.WRAP_ALG instanceof KeyLengthSupplier ?
((KeyLengthSupplier) this.WRAP_ALG).getKeyBitLength() : enc.getKeyBitLength();
return Assert.gt(bitLength, 0, "Algorithm keyBitLength must be > 0"); return Assert.gt(bitLength, 0, "Algorithm keyBitLength must be > 0");
} }
private SecretKey deriveKey(KeyRequest<?> request, PublicKey publicKey, PrivateKey privateKey) { private SecretKey deriveKey(KeyRequest<?> request, PublicKey publicKey, PrivateKey privateKey) {
AeadAlgorithm enc = Assert.notNull(request.getEncryptionAlgorithm(), "Request encryptionAlgorithm cannot be null."); AeadAlgorithm enc = Assert.notNull(request.getEncryptionAlgorithm(),
"Request encryptionAlgorithm cannot be null.");
int requiredCekBitLen = getKeyBitLength(enc); int requiredCekBitLen = getKeyBitLength(enc);
final String AlgorithmID = getConcatKDFAlgorithmId(enc); final String AlgorithmID = getConcatKDFAlgorithmId(enc);
byte[] apu = request.getHeader().getAgreementPartyUInfo(); byte[] apu = request.getHeader().getAgreementPartyUInfo();
@ -169,7 +175,8 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
} }
Assert.stateNotNull(curve, "EdwardsCurve instance cannot be null."); Assert.stateNotNull(curve, "EdwardsCurve instance cannot be null.");
if (curve.isSignatureCurve()) { if (curve.isSignatureCurve()) {
String msg = curve.getId() + " keys may not be used with ECDH-ES key agreement algorithms per " + "https://www.rfc-editor.org/rfc/rfc8037#section-3.1"; String msg = curve.getId() + " keys may not be used with ECDH-ES key agreement algorithms per " +
"https://www.rfc-editor.org/rfc/rfc8037#section-3.1";
throw new UnsupportedKeyException(msg); throw new UnsupportedKeyException(msg);
} }
return curve; return curve;
@ -183,11 +190,12 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
KeyPair pair; // generated (ephemeral) key pair KeyPair pair; // generated (ephemeral) key pair
final SecureRandom random = ensureSecureRandom(request); final SecureRandom random = ensureSecureRandom(request);
DynamicJwkBuilder<?, ?> jwkBuilder = Jwks.builder().random(random); DynamicJwkBuilder<?, ?> jwkBuilder = Jwks.builder().random(random).provider(request.getProvider());
if (publicKey instanceof ECKey) { if (publicKey instanceof ECKey) {
ECKey ecPublicKey = (ECKey) publicKey; ECKey ecPublicKey = (ECKey) publicKey;
ECParameterSpec spec = Assert.notNull(ecPublicKey.getParams(), "Encryption PublicKey params cannot be null."); ECParameterSpec spec = Assert.notNull(ecPublicKey.getParams(),
"Encryption PublicKey params cannot be null.");
// note: we don't need to validate if specified key's point is on a supported curve here // note: we don't need to validate if specified key's point is on a supported curve here
// because that will automatically be asserted when using Jwks.builder().... below // because that will automatically be asserted when using Jwks.builder().... below
pair = generateKeyPair(request, spec); pair = generateKeyPair(request, spec);
@ -197,12 +205,8 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
} else { // it must be an edwards curve key } else { // it must be an edwards curve key
EdwardsCurve curve = assertAgreement(publicKey, KEK_TYPE_MESSAGE); EdwardsCurve curve = assertAgreement(publicKey, KEK_TYPE_MESSAGE);
Provider provider = request.getProvider(); Provider provider = request.getProvider();
Provider curveProvider = curve.getProvider(); // only non-null if not natively supported by the JVM
if (provider == null && curveProvider != null) { // ensure that BC can be used if necessary:
provider = curveProvider;
request = new DefaultKeyRequest<>(request.getPayload(), provider, random, request = new DefaultKeyRequest<>(request.getPayload(), provider, random,
request.getHeader(), request.getEncryptionAlgorithm()); request.getHeader(), request.getEncryptionAlgorithm());
}
pair = generateKeyPair(random, curve, provider); pair = generateKeyPair(random, curve, provider);
jwkBuilder.provider(provider); jwkBuilder.provider(provider);
} }
@ -235,36 +239,39 @@ class EcdhKeyAlgorithm extends CryptoAlgorithm implements KeyAlgorithm<PublicKey
if (privateKey instanceof ECKey) { if (privateKey instanceof ECKey) {
ECKey ecPrivateKey = (ECKey) privateKey; ECKey ecPrivateKey = (ECKey) privateKey;
if (!(epk instanceof EcPublicJwk)) { if (!(epk instanceof EcPublicJwk)) {
String msg = "JWE Header " + DefaultJweHeader.EPK + " value is not a supported Elliptic Curve " + "Public JWK. Value: " + epk; String msg = "JWE Header " + DefaultJweHeader.EPK + " value is not a supported Elliptic Curve " +
"Public JWK. Value: " + epk;
throw new UnsupportedKeyException(msg); throw new UnsupportedKeyException(msg);
} }
EcPublicJwk ecEpk = (EcPublicJwk) epk; EcPublicJwk ecEpk = (EcPublicJwk) epk;
// While the EPK might be on a JWA-supported NIST curve, it must be on the private key's exact curve: // While the EPK might be on a JWA-supported NIST curve, it must be on the private key's exact curve:
if (!EcPublicJwkFactory.contains(ecPrivateKey.getParams().getCurve(), ecEpk.toKey().getW())) { if (!EcPublicJwkFactory.contains(ecPrivateKey.getParams().getCurve(), ecEpk.toKey().getW())) {
String msg = "JWE Header " + DefaultJweHeader.EPK + " value does not represent " + "a point on the expected curve."; String msg = "JWE Header " + DefaultJweHeader.EPK + " value does not represent " +
"a point on the expected curve.";
throw new InvalidKeyException(msg); throw new InvalidKeyException(msg);
} }
} else { // it must be an Edwards Curve key } else { // it must be an Edwards Curve key
EdwardsCurve privateKeyCurve = assertAgreement(privateKey, KDK_TYPE_MESSAGE); EdwardsCurve privateKeyCurve = assertAgreement(privateKey, KDK_TYPE_MESSAGE);
if (!(epk instanceof OctetPublicJwk)) { if (!(epk instanceof OctetPublicJwk)) {
String msg = "JWE Header " + DefaultJweHeader.EPK + " value is not a supported Elliptic Curve " + "Public JWK. Value: " + epk; String msg = "JWE Header " + DefaultJweHeader.EPK + " value is not a supported Elliptic Curve " +
"Public JWK. Value: " + epk;
throw new UnsupportedKeyException(msg); throw new UnsupportedKeyException(msg);
} }
OctetPublicJwk<?> oEpk = (OctetPublicJwk<?>) epk; OctetPublicJwk<?> oEpk = (OctetPublicJwk<?>) epk;
EdwardsCurve epkCurve = EdwardsCurve.forKey(oEpk.toKey()); EdwardsCurve epkCurve = EdwardsCurve.forKey(oEpk.toKey());
if (!privateKeyCurve.equals(epkCurve)) { if (!privateKeyCurve.equals(epkCurve)) {
String msg = "JWE Header " + DefaultJweHeader.EPK + " value does not represent a point " + "on the expected curve. Value: " + oEpk; String msg = "JWE Header " + DefaultJweHeader.EPK + " value does not represent a point " +
"on the expected curve. Value: " + oEpk;
throw new InvalidKeyException(msg); throw new InvalidKeyException(msg);
} }
Provider curveProvider = privateKeyCurve.getProvider(); request = new DefaultDecryptionKeyRequest<>(request.getPayload(), request.getProvider(),
if (request.getProvider() == null && curveProvider != null) { // ensure that BC can be used if necessary: ensureSecureRandom(request), request.getHeader(), request.getEncryptionAlgorithm(), request.getKey());
request = new DefaultDecryptionKeyRequest<>(request.getPayload(), curveProvider, ensureSecureRandom(request), request.getHeader(), request.getEncryptionAlgorithm(), request.getKey());
}
} }
final SecretKey derived = deriveKey(request, epk.toKey(), privateKey); final SecretKey derived = deriveKey(request, epk.toKey(), privateKey);
DecryptionKeyRequest<SecretKey> unwrapReq = new DefaultDecryptionKeyRequest<>(request.getPayload(), request.getProvider(), request.getSecureRandom(), header, request.getEncryptionAlgorithm(), derived); DecryptionKeyRequest<SecretKey> unwrapReq = new DefaultDecryptionKeyRequest<>(request.getPayload(),
request.getProvider(), request.getSecureRandom(), header, request.getEncryptionAlgorithm(), derived);
return WRAP_ALG.getDecryptionKey(unwrapReq); return WRAP_ALG.getDecryptionKey(unwrapReq);
} }

View File

@ -41,8 +41,6 @@ final class EdSignatureAlgorithm extends AbstractSignatureAlgorithm {
private EdSignatureAlgorithm() { private EdSignatureAlgorithm() {
super(ID, ID); super(ID, ID);
this.preferredCurve = EdwardsCurve.Ed448; this.preferredCurve = EdwardsCurve.Ed448;
// EdDSA is not available natively until JDK 15, so try to load BC as a backup provider if possible:
setProvider(this.preferredCurve.getProvider());
Assert.isTrue(this.preferredCurve.isSignatureCurve(), "Must be signature curve, not key agreement curve."); Assert.isTrue(this.preferredCurve.isSignatureCurve(), "Must be signature curve, not key agreement curve.");
} }

View File

@ -16,12 +16,7 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.impl.lang.Function; import io.jsonwebtoken.impl.lang.Function;
import io.jsonwebtoken.impl.lang.Functions;
import io.jsonwebtoken.impl.lang.OptionalCtorInvoker;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.lang.Strings; import io.jsonwebtoken.lang.Strings;
@ -32,12 +27,9 @@ import io.jsonwebtoken.security.KeyPairBuilder;
import io.jsonwebtoken.security.UnsupportedKeyException; import io.jsonwebtoken.security.UnsupportedKeyException;
import java.security.Key; import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPairGenerator;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.Provider; import java.security.Provider;
import java.security.PublicKey; import java.security.PublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.KeySpec; import java.security.spec.KeySpec;
import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec; import java.security.spec.X509EncodedKeySpec;
@ -50,21 +42,11 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
private static final String OID_PREFIX = "1.3.101."; private static final String OID_PREFIX = "1.3.101.";
// DER-encoded edwards keys have this exact sequence identifying the type of key that follows. The trailing // ASN.1-encoded edwards keys have this exact sequence identifying the type of key that follows. The trailing
// byte is the exact edwards curve subsection OID terminal node id. // byte is the exact edwards curve subsection OID terminal node id.
private static final byte[] DER_OID_PREFIX = new byte[]{0x06, 0x03, 0x2B, 0x65}; private static final byte[] ASN1_OID_PREFIX = new byte[]{0x06, 0x03, 0x2B, 0x65};
private static final String NAMED_PARAM_SPEC_FQCN = "java.security.spec.NamedParameterSpec"; // JDK >= 11
private static final String XEC_PRIV_KEY_SPEC_FQCN = "java.security.spec.XECPrivateKeySpec"; // JDK >= 11
private static final String EDEC_PRIV_KEY_SPEC_FQCN = "java.security.spec.EdECPrivateKeySpec"; // JDK >= 15
private static final Function<Key, String> CURVE_NAME_FINDER = new NamedParameterSpecValueFinder(); private static final Function<Key, String> CURVE_NAME_FINDER = new NamedParameterSpecValueFinder();
private static final OptionalCtorInvoker<AlgorithmParameterSpec> NAMED_PARAM_SPEC_CTOR =
new OptionalCtorInvoker<>(NAMED_PARAM_SPEC_FQCN, String.class);
static final OptionalCtorInvoker<KeySpec> XEC_PRIV_KEY_SPEC_CTOR =
new OptionalCtorInvoker<>(XEC_PRIV_KEY_SPEC_FQCN, AlgorithmParameterSpec.class, byte[].class);
static final OptionalCtorInvoker<KeySpec> EDEC_PRIV_KEY_SPEC_CTOR =
new OptionalCtorInvoker<>(EDEC_PRIV_KEY_SPEC_FQCN, NAMED_PARAM_SPEC_FQCN, byte[].class);
public static final EdwardsCurve X25519 = new EdwardsCurve("X25519", 110); // Requires JDK >= 11 or BC public static final EdwardsCurve X25519 = new EdwardsCurve("X25519", 110); // Requires JDK >= 11 or BC
public static final EdwardsCurve X448 = new EdwardsCurve("X448", 111); // Requires JDK >= 11 or BC public static final EdwardsCurve X448 = new EdwardsCurve("X448", 111); // Requires JDK >= 11 or BC
@ -81,24 +63,41 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
REGISTRY = new LinkedHashMap<>(8); REGISTRY = new LinkedHashMap<>(8);
BY_OID_TERMINAL_NODE = new LinkedHashMap<>(4); BY_OID_TERMINAL_NODE = new LinkedHashMap<>(4);
for (EdwardsCurve curve : VALUES) { for (EdwardsCurve curve : VALUES) {
int subcategoryId = curve.DER_OID[curve.DER_OID.length - 1]; int subcategoryId = curve.ASN1_OID[curve.ASN1_OID.length - 1];
BY_OID_TERMINAL_NODE.put(subcategoryId, curve); BY_OID_TERMINAL_NODE.put(subcategoryId, curve);
REGISTRY.put(curve.getId(), curve); REGISTRY.put(curve.getId(), curve);
REGISTRY.put(curve.OID, curve); // add OID as an alias for alg/id lookups REGISTRY.put(curve.OID, curve); // add OID as an alias for alg/id lookups
} }
} }
private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] ASN1_OID, boolean ber) {
byte[] keyPrefix = ber ?
new byte[]{0x04, (byte) (byteLength + 2), 0x04, (byte) byteLength} : // correct
new byte[]{0x04, (byte) byteLength}; // https://bugs.openjdk.org/browse/JDK-8213363
return Bytes.concat(
new byte[]{
0x30,
(byte) (5 + ASN1_OID.length + keyPrefix.length + byteLength),
0x02, 0x01, 0x00, // encoding version 1 (integer, 1 byte, value 0)
0x30, 0x05}, // ASN.1 SEQUENCE of 5 bytes to follow (i.e. the OID)
ASN1_OID,
keyPrefix
);
}
private final String OID; private final String OID;
/** /**
* The byte sequence within an DER-encoded key that indicates an Edwards curve encoded key follows. DER (hex) * The byte sequence within an ASN.1-encoded key that indicates an Edwards curve encoded key follows. ASN.1 (hex)
* notation: * notation:
* <pre> * <pre>
* 06 03 ; OBJECT IDENTIFIER (3 bytes long) * 06 03 ; OBJECT IDENTIFIER (3 bytes long)
* | 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (decimal 110, 111, 112, or 113) * | 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (decimal 110, 111, 112, or 113)
* </pre> * </pre>
*/ */
final byte[] DER_OID; final byte[] ASN1_OID;
private final int keyBitLength; private final int keyBitLength;
@ -107,42 +106,39 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
private final int encodedKeyByteLength; private final int encodedKeyByteLength;
/** /**
* X.509 (DER) encoding of a public key associated with this curve as a prefix (that is, <em>without</em> the * X.509 (ASN.1) encoding of a public key associated with this curve as a prefix (that is, <em>without</em> the
* actual encoded key material at the end). Appending the public key material directly to the end of this value * actual encoded key material at the end). Appending the public key material directly to the end of this value
* results in a complete X.509 (DER) encoded public key. DER (hex) notation: * results in a complete X.509 (ASN.1) encoded public key. ASN.1 (hex) notation:
* <pre> * <pre>
* 30 $M ; DER SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 10 * 30 $M ; ASN.1 SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 10
* 30 05 ; DER SEQUENCE (5 bytes long) * 30 05 ; ASN.1 SEQUENCE (5 bytes long)
* 06 03 ; OBJECT IDENTIFIER (3 bytes long) * 06 03 ; OBJECT IDENTIFIER (3 bytes long)
* 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal) * 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal)
* 03 $S ; DER BIT STRING ($S bytes long), where $S = encodedKeyByteLength + 1 * 03 $S ; ASN.1 BIT STRING ($S bytes long), where $S = encodedKeyByteLength + 1
* 00 ; DER bit string marker indicating zero unused bits at the end of the bit string * 00 ; ASN.1 bit string marker indicating zero unused bits at the end of the bit string
* XX XX XX ... ; encoded key material (not included in this PREFIX byte array variable) * XX XX XX ... ; encoded key material (not included in this PREFIX byte array variable)
* </pre> * </pre>
*/ */
private final byte[] PUBLIC_KEY_DER_PREFIX; private final byte[] PUBLIC_KEY_ASN1_PREFIX;
/** /**
* PKCS8 (DER) Version 1 encoding of a private key associated with this curve, as a prefix (that is, * PKCS8 (ASN.1) Version 1 encoding of a private key associated with this curve, as a prefix (that is,
* <em>without</em> actual encoded key material at the end). Appending the private key material directly to the * <em>without</em> actual encoded key material at the end). Appending the private key material directly to the
* end of this value results in a complete PKCS8 (DER) V1 encoded private key. DER (hex) notation: * end of this value results in a complete PKCS8 (ASN.1) V1 encoded private key. ASN.1 (hex) notation:
* <pre> * <pre>
* 30 $M ; DER SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 14 * 30 $M ; ASN.1 SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 14
* 02 01 ; DER INTEGER (1 byte long) * 02 01 ; ASN.1 INTEGER (1 byte long)
* 00 ; zero (private key encoding version V1) * 00 ; zero (private key encoding version V1)
* 30 05 ; DER SEQUENCE (5 bytes long) * 30 05 ; ASN.1 SEQUENCE (5 bytes long)
* 06 03 ; OBJECT IDENTIFIER (3 bytes long). This is the edwards algorithm ID. * 06 03 ; OBJECT IDENTIFIER (3 bytes long). This is the edwards algorithm ID.
* 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal) * 2B 65 $I ; "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal)
* 04 $B ; DER SEQUENCE ($B bytes long, where $B = encodedKeyByteLength + 2 * 04 $B ; ASN.1 SEQUENCE ($B bytes long, where $B = encodedKeyByteLength + 2
* 04 $K ; DER SEQUENCE ($K bytes long), where $K = encodedKeyByteLength * 04 $K ; ASN.1 SEQUENCE ($K bytes long), where $K = encodedKeyByteLength
* XX XX XX ... ; encoded key material (not included in this PREFIX byte array variable) * XX XX XX ... ; encoded key material (not included in this PREFIX byte array variable)
* </pre> * </pre>
*/ */
private final byte[] PRIVATE_KEY_DER_PREFIX; private final byte[] PRIVATE_KEY_ASN1_PREFIX;
private final byte[] PRIVATE_KEY_JDK11_PREFIX; // https://bugs.openjdk.org/browse/JDK-8213363
private final AlgorithmParameterSpec NAMED_PARAMETER_SPEC; // null on <= JDK 10
private final Function<byte[], KeySpec> PRIVATE_KEY_SPEC_FACTORY;
/** /**
* {@code true} IFF the curve is used for digital signatures, {@code false} if used for key agreement * {@code true} IFF the curve is used for digital signatures, {@code false} if used for key agreement
@ -150,14 +146,7 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
private final boolean signatureCurve; private final boolean signatureCurve;
EdwardsCurve(final String id, int oidTerminalNode) { EdwardsCurve(final String id, int oidTerminalNode) {
super(id, id, // JWT ID and JCA name happen to be identical super(id, id);
// fall back to BouncyCastle if < JDK 11 (for XDH curves) or < JDK 15 (for EdDSA curves) if necessary:
Providers.findBouncyCastle(Conditions.notExists(new CheckedSupplier<KeyPairGenerator>() {
@Override
public KeyPairGenerator get() throws Exception {
return KeyPairGenerator.getInstance(id);
}
})));
// OIDs (with terminal node IDs) defined here: https://www.rfc-editor.org/rfc/rfc8410#section-3 // OIDs (with terminal node IDs) defined here: https://www.rfc-editor.org/rfc/rfc8410#section-3
// X25519 (oid 1.3.101.110) and X448 (oid 1.3.101.111) have 256 bits // X25519 (oid 1.3.101.110) and X448 (oid 1.3.101.111) have 256 bits
@ -183,39 +172,22 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
this.OID = OID_PREFIX + oidTerminalNode; this.OID = OID_PREFIX + oidTerminalNode;
this.signatureCurve = (oidTerminalNode == 112 || oidTerminalNode == 113); this.signatureCurve = (oidTerminalNode == 112 || oidTerminalNode == 113);
byte[] suffix = new byte[]{(byte) oidTerminalNode}; byte[] suffix = new byte[]{(byte) oidTerminalNode};
this.DER_OID = Bytes.concat(DER_OID_PREFIX, suffix); this.ASN1_OID = Bytes.concat(ASN1_OID_PREFIX, suffix);
this.encodedKeyByteLength = (this.keyBitLength + 7) / 8; this.encodedKeyByteLength = (this.keyBitLength + 7) / 8;
this.PUBLIC_KEY_DER_PREFIX = Bytes.concat( this.PUBLIC_KEY_ASN1_PREFIX = Bytes.concat(
new byte[]{ new byte[]{
0x30, (byte) (this.encodedKeyByteLength + 10), 0x30, (byte) (this.encodedKeyByteLength + 10),
0x30, 0x05}, // DER SEQUENCE of 5 bytes to follow (i.e. the OID) 0x30, 0x05}, // ASN.1 SEQUENCE of 5 bytes to follow (i.e. the OID)
this.DER_OID, this.ASN1_OID,
new byte[]{ new byte[]{
0x03, 0x03,
(byte) (this.encodedKeyByteLength + 1), (byte) (this.encodedKeyByteLength + 1),
0x00} 0x00}
); );
byte[] keyPrefix = new byte[]{ this.PRIVATE_KEY_ASN1_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.ASN1_OID, true);
0x04, (byte) (this.encodedKeyByteLength + 2), this.PRIVATE_KEY_JDK11_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.ASN1_OID, false);
0x04, (byte) this.encodedKeyByteLength};
this.PRIVATE_KEY_DER_PREFIX = Bytes.concat(
new byte[]{
0x30,
(byte) (this.encodedKeyByteLength + 10 + keyPrefix.length),
0x02, 0x01, 0x00, // encoding version 1 (integer, 1 byte, value 0)
0x30, 0x05}, // DER SEQUENCE of 5 bytes to follow (i.e. the OID)
this.DER_OID,
keyPrefix
);
this.NAMED_PARAMETER_SPEC = NAMED_PARAM_SPEC_CTOR.apply(id); // null on <= JDK 10
Function<byte[], KeySpec> paramKeySpecFn = paramKeySpecFactory(NAMED_PARAMETER_SPEC, signatureCurve);
Function<byte[], KeySpec> pkcs8KeySpecFn = new Pkcs8KeySpecFactory(this.PRIVATE_KEY_DER_PREFIX);
// prefer the JDK KeySpec classes first, and fall back to PKCS8 encoding if unavailable:
this.PRIVATE_KEY_SPEC_FACTORY = Functions.firstResult(paramKeySpecFn, pkcs8KeySpecFn);
// The Sun CE KeyPairGenerator implementation that we'll use to derive PublicKeys with is problematic here: // The Sun CE KeyPairGenerator implementation that we'll use to derive PublicKeys with is problematic here:
// //
@ -241,14 +213,6 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
this.KEY_PAIR_GENERATOR_BIT_LENGTH = this.keyBitLength >= 448 ? 448 : 255; this.KEY_PAIR_GENERATOR_BIT_LENGTH = this.keyBitLength >= 448 ? 448 : 255;
} }
// visible for testing
protected static Function<byte[], KeySpec> paramKeySpecFactory(AlgorithmParameterSpec spec, boolean signatureCurve) {
if (spec == null) {
return Functions.forNull();
}
return new ParameterizedKeySpecFactory(spec, signatureCurve ? EDEC_PRIV_KEY_SPEC_CTOR : XEC_PRIV_KEY_SPEC_CTOR);
}
@Override @Override
public int getKeyBitLength() { public int getKeyBitLength() {
return this.keyBitLength; return this.keyBitLength;
@ -261,39 +225,39 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
if (t instanceof KeyException) { //propagate if (t instanceof KeyException) { //propagate
throw (KeyException) t; throw (KeyException) t;
} }
String msg = "Invalid " + getId() + " DER encoding: " + t.getMessage(); String msg = "Invalid " + getId() + " ASN.1 encoding: " + t.getMessage();
throw new InvalidKeyException(msg, t); throw new InvalidKeyException(msg, t);
} }
} }
/** /**
* Parses the DER-encoding of the specified key * Parses the ASN.1-encoding of the specified key
* *
* @param key the Edwards curve key * @param key the Edwards curve key
* @return the key value, encoded according to <a href="https://www.rfc-editor.org/rfc/rfc8032">RFC 8032</a> * @return the key value, encoded according to <a href="https://www.rfc-editor.org/rfc/rfc8032">RFC 8032</a>
* @throws RuntimeException if the key's encoded bytes do not reflect a validly DER-encoded edwards key * @throws RuntimeException if the key's encoded bytes do not reflect a validly ASN.1-encoded edwards key
*/ */
protected byte[] doGetKeyMaterial(Key key) { protected byte[] doGetKeyMaterial(Key key) {
byte[] encoded = KeysBridge.getEncoded(key); byte[] encoded = KeysBridge.getEncoded(key);
int i = Bytes.indexOf(encoded, DER_OID); int i = Bytes.indexOf(encoded, ASN1_OID);
Assert.gt(i, -1, "Missing or incorrect algorithm OID."); Assert.gt(i, -1, "Missing or incorrect algorithm OID.");
i = i + DER_OID.length; i = i + ASN1_OID.length;
int keyLen = 0; int keyLen = 0;
if (encoded[i] == 0x05) { // NULL terminator, next should be zero byte indicator if (encoded[i] == 0x05) { // NULL terminator, next should be zero byte indicator
int unusedBytes = encoded[++i]; int unusedBytes = encoded[++i];
Assert.eq(unusedBytes, 0, "OID NULL terminator should indicate zero unused bytes."); Assert.eq(unusedBytes, 0, "OID NULL terminator should indicate zero unused bytes.");
i++; i++;
} }
if (encoded[i] == 0x03) { // DER bit stream, Public Key if (encoded[i] == 0x03) { // ASN.1 bit stream, Public Key
i++; i++;
keyLen = encoded[i++]; keyLen = encoded[i++];
int unusedBytes = encoded[i++]; int unusedBytes = encoded[i++];
Assert.eq(unusedBytes, 0, "BIT STREAM should not indicate unused bytes."); Assert.eq(unusedBytes, 0, "BIT STREAM should not indicate unused bytes.");
keyLen--; keyLen--;
} else if (encoded[i] == 0x04) { // DER octet sequence, Private Key. Key length follows as next byte. } else if (encoded[i] == 0x04) { // ASN.1 octet sequence, Private Key. Key length follows as next byte.
i++; i++;
keyLen = encoded[i++]; keyLen = encoded[i++];
if (encoded[i] == 0x04) { // DER octet sequence, key length follows as next byte. if (encoded[i] == 0x04) { // ASN.1 octet sequence, key length follows as next byte.
i++; // skip sequence marker i++; // skip sequence marker
keyLen = encoded[i++]; // next byte is length keyLen = encoded[i++]; // next byte is length
} }
@ -305,13 +269,6 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
return result; return result;
} }
protected Provider fallback(Provider provider) {
if (provider == null) {
provider = getProvider();
}
return provider;
}
private void assertLength(byte[] raw, boolean isPublic) { private void assertLength(byte[] raw, boolean isPublic) {
int len = Bytes.length(raw); int len = Bytes.length(raw);
if (len != this.encodedKeyByteLength) { if (len != this.encodedKeyByteLength) {
@ -324,27 +281,23 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
public PublicKey toPublicKey(byte[] x, Provider provider) { public PublicKey toPublicKey(byte[] x, Provider provider) {
assertLength(x, true); assertLength(x, true);
final byte[] encoded = Bytes.concat(this.PUBLIC_KEY_DER_PREFIX, x); final byte[] encoded = Bytes.concat(this.PUBLIC_KEY_ASN1_PREFIX, x);
final X509EncodedKeySpec spec = new X509EncodedKeySpec(encoded); final X509EncodedKeySpec spec = new X509EncodedKeySpec(encoded);
JcaTemplate template = new JcaTemplate(getJcaName(), fallback(provider)); JcaTemplate template = new JcaTemplate(getJcaName(), provider);
return template.withKeyFactory(new CheckedFunction<KeyFactory, PublicKey>() { return template.generatePublic(spec);
@Override
public PublicKey apply(KeyFactory keyFactory) throws Exception {
return keyFactory.generatePublic(spec);
}
});
} }
public PrivateKey toPrivateKey(byte[] d, Provider provider) { KeySpec privateKeySpec(byte[] d, boolean standard) {
assertLength(d, false); byte[] prefix = standard ? this.PRIVATE_KEY_ASN1_PREFIX : this.PRIVATE_KEY_JDK11_PREFIX;
final KeySpec spec = this.PRIVATE_KEY_SPEC_FACTORY.apply(d); byte[] encoded = Bytes.concat(prefix, d);
JcaTemplate template = new JcaTemplate(getJcaName(), fallback(provider)); return new PKCS8EncodedKeySpec(encoded);
return template.withKeyFactory(new CheckedFunction<KeyFactory, PrivateKey>() {
@Override
public PrivateKey apply(KeyFactory keyFactory) throws Exception {
return keyFactory.generatePrivate(spec);
} }
});
public PrivateKey toPrivateKey(final byte[] d, Provider provider) {
assertLength(d, false);
KeySpec spec = privateKeySpec(d, true);
JcaTemplate template = new JcaTemplate(getJcaName(), provider);
return template.generatePrivate(spec);
} }
/** /**
@ -358,7 +311,7 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
@Override @Override
public KeyPairBuilder keyPair() { public KeyPairBuilder keyPair() {
return new DefaultKeyPairBuilder(getJcaName(), KEY_PAIR_GENERATOR_BIT_LENGTH).provider(getProvider()); return new DefaultKeyPairBuilder(getJcaName(), KEY_PAIR_GENERATOR_BIT_LENGTH);
} }
public static boolean isEdwards(Key key) { public static boolean isEdwards(Key key) {
@ -397,7 +350,7 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
curve = findById(alg); curve = findById(alg);
} }
if (curve == null) { // Fall back to key encoding if possible: if (curve == null) { // Fall back to key encoding if possible:
// Try to find the Key DER algorithm OID: // Try to find the Key ASN.1 algorithm OID:
byte[] encoded = KeysBridge.findEncoded(key); byte[] encoded = KeysBridge.findEncoded(key);
if (!Bytes.isEmpty(encoded)) { if (!Bytes.isEmpty(encoded)) {
int oidTerminalNode = findOidTerminalNode(encoded); int oidTerminalNode = findOidTerminalNode(encoded);
@ -411,9 +364,9 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
} }
private static int findOidTerminalNode(byte[] encoded) { private static int findOidTerminalNode(byte[] encoded) {
int index = Bytes.indexOf(encoded, DER_OID_PREFIX); int index = Bytes.indexOf(encoded, ASN1_OID_PREFIX);
if (index > -1) { if (index > -1) {
index = index + DER_OID_PREFIX.length; index = index + ASN1_OID_PREFIX.length;
if (index < encoded.length) { if (index < encoded.length) {
return encoded[index]; return encoded[index];
} }
@ -438,39 +391,4 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier {
forKey(key); // will throw UnsupportedKeyException if the key is not an Edwards key forKey(key); // will throw UnsupportedKeyException if the key is not an Edwards key
return key; return key;
} }
private static final class Pkcs8KeySpecFactory implements Function<byte[], KeySpec> {
private final byte[] PREFIX;
private Pkcs8KeySpecFactory(byte[] pkcs8EncodedKeyPrefix) {
this.PREFIX = Assert.notEmpty(pkcs8EncodedKeyPrefix, "pkcs8EncodedKeyPrefix cannot be null or empty.");
}
@Override
public KeySpec apply(byte[] d) {
Assert.notEmpty(d, "Key bytes cannot be null or empty.");
byte[] encoded = Bytes.concat(PREFIX, d);
return new PKCS8EncodedKeySpec(encoded);
}
}
// visible for testing
protected static final class ParameterizedKeySpecFactory implements Function<byte[], KeySpec> {
private final AlgorithmParameterSpec params;
private final Function<Object, KeySpec> keySpecFactory;
ParameterizedKeySpecFactory(AlgorithmParameterSpec params, Function<Object, KeySpec> keySpecFactory) {
this.params = Assert.notNull(params, "AlgorithmParameterSpec cannot be null.");
this.keySpecFactory = Assert.notNull(keySpecFactory, "KeySpec factory function cannot be null.");
}
@Override
public KeySpec apply(byte[] d) {
Assert.notEmpty(d, "Key bytes cannot be null or empty.");
Object[] args = new Object[]{params, d};
return this.keySpecFactory.apply(args);
}
}
} }

View File

@ -16,12 +16,16 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.Identifiable; import io.jsonwebtoken.Identifiable;
import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.DefaultRegistry; import io.jsonwebtoken.impl.lang.DefaultRegistry;
import io.jsonwebtoken.impl.lang.Function; import io.jsonwebtoken.impl.lang.Function;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.lang.Objects;
import io.jsonwebtoken.lang.Registry; import io.jsonwebtoken.lang.Registry;
import io.jsonwebtoken.lang.Strings;
import io.jsonwebtoken.security.SecurityException; import io.jsonwebtoken.security.SecurityException;
import io.jsonwebtoken.security.SignatureException; import io.jsonwebtoken.security.SignatureException;
@ -32,19 +36,31 @@ import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException; import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory; import javax.crypto.SecretKeyFactory;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.security.AlgorithmParameters; import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException; import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.KeyFactory; import java.security.KeyFactory;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.KeyPairGenerator; import java.security.KeyPairGenerator;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.Provider; import java.security.Provider;
import java.security.PublicKey;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.security.Signature; import java.security.Signature;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory; import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.AlgorithmParameterSpec; import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.List; import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
public class JcaTemplate { public class JcaTemplate {
@ -71,6 +87,11 @@ public class JcaTemplate {
} }
}); });
// visible for testing
protected Provider findBouncyCastle() {
return Providers.findBouncyCastle();
}
private final String jcaName; private final String jcaName;
private final Provider provider; private final Provider provider;
private final SecureRandom secureRandom; private final SecureRandom secureRandom;
@ -85,10 +106,54 @@ public class JcaTemplate {
this.provider = provider; //may be null, meaning to use the JCA subsystem default provider this.provider = provider; //may be null, meaning to use the JCA subsystem default provider
} }
private <T, R> R execute(Class<T> clazz, CheckedFunction<T, R> fn) throws SecurityException { private <T, R> R execute(Class<T> clazz, CheckedFunction<T, R> callback, Provider provider) throws Exception {
InstanceFactory<?> factory = REGISTRY.get(clazz); InstanceFactory<?> factory = REGISTRY.get(clazz);
Assert.notNull(factory, "Unsupported JCA instance class."); Assert.notNull(factory, "Unsupported JCA instance class.");
return execute(factory, clazz, fn);
Object object = factory.get(this.jcaName, provider);
T instance = Assert.isInstanceOf(clazz, object, "Factory instance does not match expected type.");
return callback.apply(instance);
}
private <T> T execute(Class<?> clazz, CheckedSupplier<T> fn) throws SecurityException {
try {
return fn.get();
} catch (SecurityException se) {
throw se; //propagate
} catch (Throwable t) {
String msg = clazz.getSimpleName() + " callback execution failed: " + t.getMessage();
throw new SecurityException(msg, t);
}
}
private <T, R> R execute(final Class<T> clazz, final CheckedFunction<T, R> fn) throws SecurityException {
return execute(clazz, new CheckedSupplier<R>() {
@Override
public R get() throws Exception {
return execute(clazz, fn, JcaTemplate.this.provider);
}
});
}
protected <T, R> R fallback(final Class<T> clazz, final CheckedFunction<T, R> callback) throws SecurityException {
return execute(clazz, new CheckedSupplier<R>() {
@Override
public R get() throws Exception {
try {
return execute(clazz, callback, JcaTemplate.this.provider);
} catch (Exception e) {
try { // fallback
Provider bc = findBouncyCastle();
if (bc != null) {
return execute(clazz, callback, bc);
}
} catch (Throwable ignored) { // report original exception instead
}
throw e;
}
}
});
} }
public <R> R withCipher(CheckedFunction<Cipher, R> fn) throws SecurityException { public <R> R withCipher(CheckedFunction<Cipher, R> fn) throws SecurityException {
@ -174,17 +239,99 @@ public class JcaTemplate {
}); });
} }
// protected visibility for testing public PublicKey generatePublic(final KeySpec spec) {
private <T, R> R execute(InstanceFactory<?> factory, Class<T> clazz, CheckedFunction<T, R> callback) throws SecurityException { return fallback(KeyFactory.class, new CheckedFunction<KeyFactory, PublicKey>() {
try { @Override
Object object = factory.get(this.jcaName, this.provider); public PublicKey apply(KeyFactory keyFactory) throws Exception {
T instance = Assert.isInstanceOf(clazz, object, "Factory instance does not match expected type."); return keyFactory.generatePublic(spec);
return callback.apply(instance);
} catch (SecurityException se) {
throw se; //propagate
} catch (Exception e) {
throw new SecurityException(factory.getId() + " callback execution failed: " + e.getMessage(), e);
} }
});
}
protected boolean isJdk11() {
return System.getProperty("java.version").startsWith("11");
}
private boolean isJdk8213363Bug(InvalidKeySpecException e) {
return isJdk11() &&
("XDH".equals(this.jcaName) || "X25519".equals(this.jcaName) || "X448".equals(this.jcaName)) &&
e.getCause() instanceof InvalidKeyException &&
!Objects.isEmpty(e.getStackTrace()) &&
"sun.security.ec.XDHKeyFactory".equals(e.getStackTrace()[0].getClassName()) &&
"engineGeneratePrivate".equals(e.getStackTrace()[0].getMethodName());
}
// visible for testing
private int getJdk8213363BugExpectedSize(InvalidKeyException e) {
String msg = e.getMessage();
String prefix = "key length must be ";
if (Strings.hasText(msg) && msg.startsWith(prefix)) {
String expectedSizeString = msg.substring(prefix.length());
try {
return Integer.parseInt(expectedSizeString);
} catch (NumberFormatException ignored) { // return -1 below
}
}
return -1;
}
private KeySpec respecIfNecessary(InvalidKeySpecException e, KeySpec spec) {
if (!(spec instanceof PKCS8EncodedKeySpec)) {
return null;
}
PKCS8EncodedKeySpec pkcs8Spec = (PKCS8EncodedKeySpec) spec;
byte[] encoded = pkcs8Spec.getEncoded();
// Address the [JDK 11 SunCE provider bug](https://bugs.openjdk.org/browse/JDK-8213363) for X25519
// and X448 encoded keys: Even though the key material might be encoded properly, JDK 11's
// SunCE provider incorrectly expects an ASN.1 OCTET STRING (without the DER tag/length prefix)
// when it should actually be a BER-encoded OCTET STRING (with the tag/length prefix).
// So we get the raw key bytes and use our key factory method:
if (isJdk8213363Bug(e)) {
InvalidKeyException cause = // asserted in isJdk8213363Bug method
Assert.isInstanceOf(InvalidKeyException.class, e.getCause(), "Unexpected argument.");
int size = getJdk8213363BugExpectedSize(cause);
if ((size == 32 || size == 56) && Bytes.length(encoded) >= size) {
byte[] adjusted = new byte[size];
System.arraycopy(encoded, encoded.length - size, adjusted, 0, size);
EdwardsCurve curve = size == 32 ? EdwardsCurve.X25519 : EdwardsCurve.X448;
return curve.privateKeySpec(adjusted, false);
}
}
return null;
}
// visible for testing
protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException {
return factory.generatePrivate(spec);
}
public PrivateKey generatePrivate(final KeySpec spec) {
return fallback(KeyFactory.class, new CheckedFunction<KeyFactory, PrivateKey>() {
@Override
public PrivateKey apply(KeyFactory keyFactory) throws Exception {
try {
return generatePrivate(keyFactory, spec);
} catch (InvalidKeySpecException e) {
KeySpec respec = respecIfNecessary(e, spec);
if (respec != null) {
return generatePrivate(keyFactory, respec);
}
throw e; // could not respec, propagate
}
}
});
}
public X509Certificate generateX509Certificate(final byte[] x509DerBytes) {
return fallback(CertificateFactory.class, new CheckedFunction<CertificateFactory, X509Certificate>() {
@Override
public X509Certificate apply(CertificateFactory cf) throws CertificateException {
InputStream is = new ByteArrayInputStream(x509DerBytes);
return (X509Certificate) cf.generateCertificate(is);
}
});
} }
private interface InstanceFactory<T> extends Identifiable { private interface InstanceFactory<T> extends Identifiable {
@ -198,6 +345,9 @@ public class JcaTemplate {
private final Class<T> clazz; private final Class<T> clazz;
// Boolean value: missing/null = haven't attempted, true = attempted and succeeded, false = attempted and failed
private final ConcurrentMap<String, Boolean> FALLBACK_ATTEMPTS = new ConcurrentHashMap<>();
JcaInstanceFactory(Class<T> clazz) { JcaInstanceFactory(Class<T> clazz) {
this.clazz = Assert.notNull(clazz, "Class argument cannot be null."); this.clazz = Assert.notNull(clazz, "Class argument cannot be null.");
} }
@ -212,25 +362,66 @@ public class JcaTemplate {
return clazz.getSimpleName(); return clazz.getSimpleName();
} }
// visible for testing
protected Provider findBouncyCastle() {
return Providers.findBouncyCastle();
}
@SuppressWarnings("GrazieInspection")
@Override @Override
public final T get(String jcaName, Provider provider) throws Exception { public final T get(String jcaName, final Provider specifiedProvider) throws Exception {
Assert.hasText(jcaName, "jcaName cannot be null or empty."); Assert.hasText(jcaName, "jcaName cannot be null or empty.");
Provider provider = specifiedProvider;
final Boolean attempted = FALLBACK_ATTEMPTS.get(jcaName);
if (provider == null && attempted != null && attempted) {
// We tried with the default provider previously, and needed to fallback, so just
// preemptively load the fallback to avoid the fallback/retry again:
provider = findBouncyCastle();
}
try { try {
return doGet(jcaName, provider); return doGet(jcaName, provider);
} catch (Exception e) { } catch (NoSuchAlgorithmException nsa) { // try to fallback if possible
String msg = "Unable to obtain " + getId() + " instance from ";
if (provider != null) { if (specifiedProvider == null && attempted == null) { // default provider doesn't support the alg name,
msg += "specified Provider '" + provider + "' "; // and we haven't tried BC yet, so try that now:
} else { Provider fallback = findBouncyCastle();
msg += "default JCA Provider "; if (fallback != null) { // BC found, try again:
try {
T value = doGet(jcaName, fallback);
// record the successful attempt so we don't have to do this again:
FALLBACK_ATTEMPTS.putIfAbsent(jcaName, Boolean.TRUE);
return value;
} catch (Throwable ignored) {
// record the failed attempt so we don't keep trying and propagate original exception:
FALLBACK_ATTEMPTS.putIfAbsent(jcaName, Boolean.FALSE);
} }
msg += "for JCA algorithm '" + jcaName + "': " + e.getMessage(); }
throw wrap(msg, e); }
// otherwise, we tried the fallback, or there isn't a fallback, so no need to try again, so
// propagate the exception:
throw wrap(nsa, jcaName, specifiedProvider, null);
} catch (Exception e) {
throw wrap(e, jcaName, specifiedProvider, null);
} }
} }
protected abstract T doGet(String jcaName, Provider provider) throws Exception; protected abstract T doGet(String jcaName, Provider provider) throws Exception;
// visible for testing:
protected Exception wrap(Exception e, String jcaName, Provider specifiedProvider, Provider fallbackProvider) {
String msg = "Unable to obtain '" + jcaName + "' " + getId() + " instance from ";
if (specifiedProvider != null) {
msg += "specified '" + specifiedProvider + "' Provider";
} else {
msg += "default JCA Provider";
}
if (fallbackProvider != null) {
msg += " or fallback '" + fallbackProvider + "' Provider";
}
msg += ": " + e.getMessage();
return wrap(msg, e);
}
protected Exception wrap(String msg, Exception cause) { protected Exception wrap(String msg, Exception cause) {
if (Signature.class.isAssignableFrom(clazz) || Mac.class.isAssignableFrom(clazz)) { if (Signature.class.isAssignableFrom(clazz) || Mac.class.isAssignableFrom(clazz)) {
return new SignatureException(msg, cause); return new SignatureException(msg, cause);

View File

@ -16,20 +16,13 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.impl.lang.Converter; import io.jsonwebtoken.impl.lang.Converter;
import io.jsonwebtoken.io.Decoders; import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.io.Encoders; import io.jsonwebtoken.io.Encoders;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Strings;
import io.jsonwebtoken.security.SecurityException; import io.jsonwebtoken.security.SecurityException;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.security.Provider;
import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
public class JwtX509StringConverter implements Converter<X509Certificate, String> { public class JwtX509StringConverter implements Converter<X509Certificate, String> {
@ -59,49 +52,19 @@ public class JwtX509StringConverter implements Converter<X509Certificate, String
} }
// visible for testing // visible for testing
protected X509Certificate toCert(final byte[] der, Provider provider) throws SecurityException { protected X509Certificate toCert(final byte[] der) throws SecurityException {
JcaTemplate template = new JcaTemplate("X.509", provider); return new JcaTemplate("X.509", null).generateX509Certificate(der);
final InputStream is = new ByteArrayInputStream(der);
return template.withCertificateFactory(new CheckedFunction<CertificateFactory, X509Certificate>() {
@Override
public X509Certificate apply(CertificateFactory cf) throws Exception {
return (X509Certificate) cf.generateCertificate(is);
}
});
} }
@Override @Override
public X509Certificate applyFrom(String s) { public X509Certificate applyFrom(String s) {
Assert.hasText(s, "X.509 Certificate encoded string cannot be null or empty."); Assert.hasText(s, "X.509 Certificate encoded string cannot be null or empty.");
byte[] der = null;
try { try {
der = Decoders.BASE64.decode(s); //RFC requires Base64, not Base64Url byte[] der = Decoders.BASE64.decode(s); //RFC requires Base64, not Base64Url
return toCert(der, null); return toCert(der);
} catch (final Throwable t) { } catch (Exception e) {
String msg = "Unable to convert Base64 String '" + s + "' to X509Certificate instance. Cause: " + e.getMessage();
// Some JDK implementations don't support RSASSA-PSS certificates: throw new IllegalArgumentException(msg, e);
//
// https://bugs.openjdk.org/browse/JDK-8242556
//
// Oracle only backported this fix to JDK 8u271+, 11.0.9+, and 15+, so we'll try to fall back to
// BC (which can read the files correctly) on JDK 9, 10, 12, 13, and 14:
String causeMsg = t.getMessage();
Provider bc = null;
if (!Bytes.isEmpty(der) && // Base64 decoding succeeded, so we can continue to try
Strings.hasText(causeMsg) && causeMsg.contains(RsaSignatureAlgorithm.PSS_OID)) {
// OID in exception message, so odds are high that the default provider doesn't support X.509
// certificates with a PSS_OID `AlgorithmId`. But BC does, so try to obtain that if we can:
bc = Providers.findBouncyCastle(Conditions.TRUE);
}
if (bc != null) {
try {
return toCert(der, bc);
} catch (Throwable ignored) {
// ignore this - we want to report the original exception to the caller
}
}
String msg = "Unable to convert Base64 String '" + s + "' to X509Certificate instance. Cause: " + causeMsg;
throw new IllegalArgumentException(msg, t);
} }
} }
} }

View File

@ -25,11 +25,11 @@ import java.security.spec.AlgorithmParameterSpec;
public class NamedParameterSpecValueFinder implements Function<Key, String> { public class NamedParameterSpecValueFinder implements Function<Key, String> {
private static final Function<Key, AlgorithmParameterSpec> EDEC_KEY_GET_PARAMS = private static final Function<Key, AlgorithmParameterSpec> EDEC_KEY_GET_PARAMS =
new OptionalMethodInvoker<>("java.security.interfaces.EdECKey", "getParams"); new OptionalMethodInvoker<>("java.security.interfaces.EdECKey", "getParams"); // >= JDK 15
private static final Function<Key, AlgorithmParameterSpec> XEC_KEY_GET_PARAMS = private static final Function<Key, AlgorithmParameterSpec> XEC_KEY_GET_PARAMS =
new OptionalMethodInvoker<>("java.security.interfaces.XECKey", "getParams"); new OptionalMethodInvoker<>("java.security.interfaces.XECKey", "getParams"); // >= JDK 11
private static final Function<Object, String> GET_NAME = private static final Function<Object, String> GET_NAME =
new OptionalMethodInvoker<>("java.security.spec.NamedParameterSpec", "getName"); new OptionalMethodInvoker<>("java.security.spec.NamedParameterSpec", "getName"); // >= JDK 11
private static final Function<Key, String> COMPOSED = Functions.andThen(Functions.firstResult(EDEC_KEY_GET_PARAMS, XEC_KEY_GET_PARAMS), GET_NAME); private static final Function<Key, String> COMPOSED = Functions.andThen(Functions.firstResult(EDEC_KEY_GET_PARAMS, XEC_KEY_GET_PARAMS), GET_NAME);

View File

@ -19,8 +19,6 @@ import io.jsonwebtoken.JweHeader;
import io.jsonwebtoken.impl.DefaultJweHeader; import io.jsonwebtoken.impl.DefaultJweHeader;
import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.impl.lang.FieldReadable; import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.RequiredFieldReader; import io.jsonwebtoken.impl.lang.RequiredFieldReader;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
@ -114,16 +112,6 @@ public class Pbes2HsAkwAlgorithm extends CryptoAlgorithm implements KeyAlgorithm
this.DERIVED_KEY_BIT_LENGTH = hashBitLength / 2; // results in 128, 192, or 256 this.DERIVED_KEY_BIT_LENGTH = hashBitLength / 2; // results in 128, 192, or 256
this.SALT_PREFIX = toRfcSaltPrefix(getId().getBytes(StandardCharsets.UTF_8)); this.SALT_PREFIX = toRfcSaltPrefix(getId().getBytes(StandardCharsets.UTF_8));
// PBKDF2WithHmacSHA* algorithms are only available on JDK 8 and later, so enable BC as a backup provider if
// necessary for <= JDK 7:
// TODO: remove when dropping Java 7 support:
setProvider(Providers.findBouncyCastle(Conditions.notExists(new CheckedSupplier<SecretKeyFactory>() {
@Override
public SecretKeyFactory get() throws Exception {
return SecretKeyFactory.getInstance(getJcaName());
}
})));
} }
// protected visibility for testing // protected visibility for testing

View File

@ -15,7 +15,6 @@
*/ */
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.Condition;
import io.jsonwebtoken.lang.Classes; import io.jsonwebtoken.lang.Classes;
import java.security.Provider; import java.security.Provider;
@ -34,7 +33,21 @@ final class Providers {
private Providers() { private Providers() {
} }
private static Provider findBouncyCastle() { /**
* Returns the BouncyCastle provider if and only if BouncyCastle is available, or {@code null} otherwise.
*
* <p>If the JVM runtime already has BouncyCastle registered
* (e.g. {@code Security.addProvider(bcProvider)}, that Provider instance will be found and returned.
* If an existing BC provider is not found, a new BC instance will be created, cached for future reference,
* and returned.</p>
*
* <p>If a new BC provider is created and returned, it is <em>not</em> registered in the JVM via
* {@code Security.addProvider} to ensure JJWT doesn't interfere with the application security provider
* configuration and/or expectations.</p>
*
* @return any available BouncyCastle Provider, or {@code null} if BouncyCastle is not available.
*/
public static Provider findBouncyCastle() {
if (!BOUNCY_CASTLE_AVAILABLE) { if (!BOUNCY_CASTLE_AVAILABLE) {
return null; return null;
} }
@ -58,28 +71,4 @@ final class Providers {
} }
return provider; return provider;
} }
/**
* Returns the BouncyCastle provider if and only if the specified Condition evaluates to {@code true}
* <em>and</em> BouncyCastle is available. Returns {@code null} otherwise.
*
* <p>If the condition evaluates to true and the JVM runtime already has BouncyCastle registered
* (e.g. {@code Security.addProvider(bcProvider)}, that Provider instance will be found and returned.
* If an existing BC provider is not found, a new BC instance will be created, cached for future reference,
* and returned.</p>
*
* <p>If a new BC provider is created and returned, it is <em>not</em> registered in the JVM via
* {@code Security.addProvider} to ensure JJWT doesn't interfere with the application security provider
* configuration and/or expectations.</p>
*
* @param c condition to evaluate
* @return any available BouncyCastle Provider if {@code c} evaluates to true, or {@code null} if either
* {@code c} evaluates to false, or BouncyCastle is not available.
*/
public static Provider findBouncyCastle(Condition c) {
if (c.test()) {
return findBouncyCastle();
}
return null;
}
} }

View File

@ -16,8 +16,6 @@
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.lang.Strings; import io.jsonwebtoken.lang.Strings;
@ -106,13 +104,6 @@ final class RsaSignatureAlgorithm extends AbstractSignatureAlgorithm {
// RSASSA-PSS constructor // RSASSA-PSS constructor
private RsaSignatureAlgorithm(int digestBitLength, AlgorithmParameterSpec paramSpec) { private RsaSignatureAlgorithm(int digestBitLength, AlgorithmParameterSpec paramSpec) {
this("PS" + digestBitLength, PSS_JCA_NAME, digestBitLength, paramSpec); this("PS" + digestBitLength, PSS_JCA_NAME, digestBitLength, paramSpec);
// RSASSA-PSS is not available natively until JDK 11, so try to load BC as a backup provider if possible:
setProvider(Providers.findBouncyCastle(Conditions.notExists(new CheckedSupplier<Signature>() {
@Override
public Signature get() throws Exception {
return Signature.getInstance(PSS_JCA_NAME);
}
})));
} }
static SignatureAlgorithm findByKey(Key key) { static SignatureAlgorithm findByKey(Key key) {
@ -173,9 +164,7 @@ final class RsaSignatureAlgorithm extends AbstractSignatureAlgorithm {
// return new DefaultKeyPairBuilder(jcaName, keyGenSpec).provider(getProvider()).random(Randoms.secureRandom()); // return new DefaultKeyPairBuilder(jcaName, keyGenSpec).provider(getProvider()).random(Randoms.secureRandom());
// //
return new DefaultKeyPairBuilder(jcaName, this.preferredKeyBitLength) return new DefaultKeyPairBuilder(jcaName, this.preferredKeyBitLength).random(Randoms.secureRandom());
.provider(getProvider())
.random(Randoms.secureRandom());
} }
@Override @Override

View File

@ -15,18 +15,11 @@
*/ */
package io.jsonwebtoken.impl.security; package io.jsonwebtoken.impl.security;
import io.jsonwebtoken.impl.lang.CheckedSupplier;
import io.jsonwebtoken.impl.lang.Conditions;
import io.jsonwebtoken.impl.lang.DelegatingRegistry; import io.jsonwebtoken.impl.lang.DelegatingRegistry;
import io.jsonwebtoken.impl.lang.IdRegistry; import io.jsonwebtoken.impl.lang.IdRegistry;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Collections; import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.HashAlgorithm; import io.jsonwebtoken.security.HashAlgorithm;
import java.security.MessageDigest;
import java.security.Provider;
import java.util.Locale;
/** /**
* Backing implementation for the {@link io.jsonwebtoken.security.Jwks.HASH} implementation. * Backing implementation for the {@link io.jsonwebtoken.security.Jwks.HASH} implementation.
* *
@ -35,37 +28,18 @@ import java.util.Locale;
@SuppressWarnings("unused") // used via reflection in io.jsonwebtoken.security.Jwks.HASH @SuppressWarnings("unused") // used via reflection in io.jsonwebtoken.security.Jwks.HASH
public class StandardHashAlgorithms extends DelegatingRegistry<String, HashAlgorithm> { public class StandardHashAlgorithms extends DelegatingRegistry<String, HashAlgorithm> {
private static class MessageDigestSupplier implements CheckedSupplier<MessageDigest> {
private final String jcaName;
private MessageDigestSupplier(String jcaName) {
this.jcaName = Assert.hasText(jcaName, "jcaName cannot be null or empty.");
}
@Override
public MessageDigest get() throws Exception {
return MessageDigest.getInstance(jcaName);
}
}
private static DefaultHashAlgorithm fallbackProvider(String id) {
String jcaName = id.toUpperCase(Locale.ENGLISH);
Provider provider = Providers.findBouncyCastle(Conditions.notExists(new MessageDigestSupplier(jcaName)));
return new DefaultHashAlgorithm(id, jcaName, provider);
}
public StandardHashAlgorithms() { public StandardHashAlgorithms() {
super(new IdRegistry<>("IANA Hash Algorithm", Collections.of( super(new IdRegistry<>("IANA Hash Algorithm", Collections.<HashAlgorithm>of(
// We don't include DefaultHashAlgorithm.SHA1 here on purpose because 1) it's not in the JWK IANA // We don't include DefaultHashAlgorithm.SHA1 here on purpose because 1) it's not in the JWK IANA
// registry so we don't need to expose it anyway, and 2) we don't want to expose a less-safe algorithm. // registry so we don't need to expose it anyway, and 2) we don't want to expose a less-safe algorithm.
// The SHA1 instance only exists in JJWT's codebase to support RFC-required `x5t` // The SHA1 instance only exists in JJWT's codebase to support RFC-required `x5t`
// (X.509 SHA-1 Thumbprint) computation - we don't use it anywhere else. // (X.509 SHA-1 Thumbprint) computation - we don't use it anywhere else.
(HashAlgorithm) new DefaultHashAlgorithm("sha-256"), new DefaultHashAlgorithm("sha-256"),
new DefaultHashAlgorithm("sha-384"), new DefaultHashAlgorithm("sha-384"),
new DefaultHashAlgorithm("sha-512"), new DefaultHashAlgorithm("sha-512"),
fallbackProvider("sha3-256"), new DefaultHashAlgorithm("sha3-256"),
fallbackProvider("sha3-384"), new DefaultHashAlgorithm("sha3-384"),
fallbackProvider("sha3-512") new DefaultHashAlgorithm("sha3-512")
))); )));
} }
} }

View File

@ -1,68 +0,0 @@
/*
* Copyright © 2023 jsonwebtoken.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.jsonwebtoken.impl.lang
import org.junit.Test
import javax.crypto.spec.PBEKeySpec
import static org.junit.Assert.*
class OptionalCtorInvokerTest {
@Test
void testCtorWithClassArg() {
String foo = 'test'
def fn = new OptionalCtorInvoker<>("java.lang.String", String.class) // copy constructor
def result = fn.apply(foo)
assertEquals foo, result
}
@Test
void testCtorWithFqcnArg() {
String foo = 'test'
def fn = new OptionalCtorInvoker<>("java.lang.String", "java.lang.String") // copy constructor
def result = fn.apply(foo)
assertEquals foo, result
}
@Test
void testCtorWithMultipleMixedArgTypes() {
char[] chars = "foo".toCharArray()
byte[] salt = [0x00, 0x01, 0x02, 0x03] as byte[]
int iterations = 256
def fn = new OptionalCtorInvoker<>("javax.crypto.spec.PBEKeySpec", char[].class, byte[].class, int.class) //password, salt, iteration count
def args = [chars, salt, iterations] as Object[]
def result = fn.apply(args) as PBEKeySpec
assertArrayEquals chars, result.getPassword()
assertArrayEquals salt, result.getSalt()
assertEquals iterations, result.getIterationCount()
}
@Test
void testZeroArgConstructor() {
OptionalCtorInvoker fn = new OptionalCtorInvoker("java.util.LinkedHashMap")
Object args = null
def result = fn.apply(args)
assertTrue result instanceof LinkedHashMap
}
@Test
void testMissingConstructor() {
def fn = new OptionalCtorInvoker('com.foo.Bar')
assertNull fn.apply(null)
}
}

View File

@ -15,7 +15,7 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.security.Jwk import io.jsonwebtoken.security.Jwk
import io.jsonwebtoken.security.Jwks import io.jsonwebtoken.security.Jwks
import io.jsonwebtoken.security.MalformedKeyException import io.jsonwebtoken.security.MalformedKeyException
@ -144,9 +144,10 @@ class AbstractJwkBuilderTest {
@Test @Test
void testProvider() { void testProvider() {
def provider = Providers.findBouncyCastle(Conditions.TRUE) def provider = TestKeys.BC
def jwk = builder().provider(provider).build() def jwk = builder().provider(provider).build()
assertEquals 'oct', jwk.getType() assertEquals 'oct', jwk.getType()
assertSame provider, jwk.@context.@provider
} }
@Test @Test

View File

@ -21,16 +21,12 @@ import io.jsonwebtoken.MalformedJwtException
import io.jsonwebtoken.impl.DefaultMutableJweHeader import io.jsonwebtoken.impl.DefaultMutableJweHeader
import io.jsonwebtoken.impl.lang.Bytes import io.jsonwebtoken.impl.lang.Bytes
import io.jsonwebtoken.impl.lang.CheckedFunction import io.jsonwebtoken.impl.lang.CheckedFunction
import io.jsonwebtoken.impl.lang.CheckedSupplier
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.lang.Arrays import io.jsonwebtoken.lang.Arrays
import io.jsonwebtoken.security.SecretKeyBuilder import io.jsonwebtoken.security.SecretKeyBuilder
import org.junit.Test import org.junit.Test
import javax.crypto.Cipher import javax.crypto.Cipher
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.GCMParameterSpec import javax.crypto.spec.GCMParameterSpec
import java.security.Provider
import static org.junit.Assert.* import static org.junit.Assert.*
@ -54,17 +50,7 @@ class AesGcmKeyAlgorithmTest {
final String jcaName = "AES/GCM/NoPadding" final String jcaName = "AES/GCM/NoPadding"
// AES/GCM/NoPadding is only available on JDK 8 and later, so enable BC as a backup provider if JcaTemplate template = new JcaTemplate(jcaName, null)
// necessary for <= JDK 7:
// TODO: remove when dropping Java 7 support:
Provider provider = Providers.findBouncyCastle(Conditions.notExists(new CheckedSupplier<SecretKeyFactory>() {
@Override
SecretKeyFactory get() throws Exception {
return SecretKeyFactory.getInstance(jcaName)
}
}))
JcaTemplate template = new JcaTemplate(jcaName, provider)
byte[] jcaResult = template.withCipher(new CheckedFunction<Cipher, byte[]>() { byte[] jcaResult = template.withCipher(new CheckedFunction<Cipher, byte[]>() {
@Override @Override
byte[] apply(Cipher cipher) throws Exception { byte[] apply(Cipher cipher) throws Exception {

View File

@ -15,12 +15,9 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.security.Request
import org.junit.Test import org.junit.Test
import java.security.Provider
import static org.easymock.EasyMock.*
import static org.junit.Assert.* import static org.junit.Assert.*
class CryptoAlgorithmTest { class CryptoAlgorithmTest {
@ -70,54 +67,6 @@ class CryptoAlgorithmTest {
assertSame Randoms.secureRandom(), random assertSame Randoms.secureRandom(), random
} }
@Test
void testRequestProviderPriorityOverDefaultProvider() {
def alg = new TestCryptoAlgorithm('test', 'test')
Provider defaultProvider = createMock(Provider)
Provider requestProvider = createMock(Provider)
Request request = createMock(Request)
alg.setProvider(defaultProvider)
expect(request.getProvider()).andReturn(requestProvider)
replay request, requestProvider, defaultProvider
assertSame requestProvider, alg.getProvider(request) // assert we get back the request provider, not the default
verify request, requestProvider, defaultProvider
}
@Test
void testMissingRequestProviderUsesDefaultProvider() {
def alg = new TestCryptoAlgorithm('test', 'test')
Provider defaultProvider = createMock(Provider)
Request request = createMock(Request)
alg.setProvider(defaultProvider)
expect(request.getProvider()).andReturn(null)
replay request, defaultProvider
assertSame defaultProvider, alg.getProvider(request) // assert we get back the default provider
verify request, defaultProvider
}
@Test
void testMissingRequestAndDefaultProviderReturnsNull() {
def alg = new TestCryptoAlgorithm('test', 'test')
Request request = createMock(Request)
expect(request.getProvider()).andReturn(null)
replay request
assertNull alg.getProvider(request) // null return value means use JCA internal default provider
verify request
}
class TestCryptoAlgorithm extends CryptoAlgorithm { class TestCryptoAlgorithm extends CryptoAlgorithm {
TestCryptoAlgorithm(String id, String jcaName) { TestCryptoAlgorithm(String id, String jcaName) {
super(id, jcaName) super(id, jcaName)

View File

@ -15,7 +15,7 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.impl.lang.Services import io.jsonwebtoken.impl.lang.Services
import io.jsonwebtoken.io.DeserializationException import io.jsonwebtoken.io.DeserializationException
import io.jsonwebtoken.io.Deserializer import io.jsonwebtoken.io.Deserializer
@ -26,7 +26,6 @@ import org.junit.Test
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.security.Key import java.security.Key
import java.security.Provider
import static org.junit.Assert.* import static org.junit.Assert.*
@ -45,13 +44,7 @@ class DefaultJwkParserTest {
def serializer = Services.loadFirst(Serializer) def serializer = Services.loadFirst(Serializer)
for (Key key : keys) { for (Key key : keys) {
//noinspection GroovyAssignabilityCheck //noinspection GroovyAssignabilityCheck
Provider provider = null // assume default, but switch if key requires it def jwk = Jwks.builder().key(key).build()
if (key.getClass().getName().startsWith("org.bouncycastle.")) {
// No native JVM support for the key, so we need to enable BC:
provider = Providers.findBouncyCastle(Conditions.TRUE)
}
//noinspection GroovyAssignabilityCheck
def jwk = Jwks.builder().provider(provider).key(key).build()
def data = serializer.serialize(jwk) def data = serializer.serialize(jwk)
String json = new String(data, StandardCharsets.UTF_8) String json = new String(data, StandardCharsets.UTF_8)
def parsed = Jwks.parser().build().parse(json) def parsed = Jwks.parser().build().parse(json)
@ -70,30 +63,18 @@ class DefaultJwkParserTest {
} }
def serializer = Services.loadFirst(Serializer) def serializer = Services.loadFirst(Serializer)
def provider = Providers.findBouncyCastle(Conditions.TRUE) //always used def provider = TestKeys.BC //always used
for (Key key : keys) { for (Key key : keys) {
//noinspection GroovyAssignabilityCheck //noinspection GroovyAssignabilityCheck
def jwk = Jwks.builder().provider(provider).key(key).build() def jwk = Jwks.builder().provider(provider).key(key).build()
def data = serializer.serialize(jwk) def data = serializer.serialize(jwk)
String json = new String(data, StandardCharsets.UTF_8) String json = new String(data, StandardCharsets.UTF_8)
def parsed = Jwks.parser().build().parse(json)
assertEquals jwk, parsed
//assertSame provider, parsed.@context.@provider
}
}
@Test
void testParseWithProvider() {
def provider = Providers.findBouncyCastle(Conditions.TRUE)
def jwk = Jwks.builder().provider(provider).key(TestKeys.HS256).build()
def serializer = Services.loadFirst(Serializer)
def data = serializer.serialize(jwk)
String json = new String(data, StandardCharsets.UTF_8)
def parsed = Jwks.parser().provider(provider).build().parse(json) def parsed = Jwks.parser().provider(provider).build().parse(json)
assertEquals jwk, parsed assertEquals jwk, parsed
assertSame provider, parsed.@context.@provider assertSame provider, parsed.@context.@provider
} }
}
@Test @Test
void testDeserializationFailure() { void testDeserializationFailure() {

View File

@ -20,7 +20,6 @@ import io.jsonwebtoken.Jwts
import io.jsonwebtoken.MalformedJwtException import io.jsonwebtoken.MalformedJwtException
import io.jsonwebtoken.impl.DefaultJweHeader import io.jsonwebtoken.impl.DefaultJweHeader
import io.jsonwebtoken.impl.DefaultMutableJweHeader import io.jsonwebtoken.impl.DefaultMutableJweHeader
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.security.DecryptionKeyRequest import io.jsonwebtoken.security.DecryptionKeyRequest
import io.jsonwebtoken.security.InvalidKeyException import io.jsonwebtoken.security.InvalidKeyException
import io.jsonwebtoken.security.Jwks import io.jsonwebtoken.security.Jwks
@ -46,7 +45,7 @@ class EcdhKeyAlgorithmTest {
def alg = new EcdhKeyAlgorithm() def alg = new EcdhKeyAlgorithm()
PublicKey encKey = TestKeys.X25519.pair.public as PublicKey PublicKey encKey = TestKeys.X25519.pair.public as PublicKey
def header = new DefaultMutableJweHeader(Jwts.header()) def header = new DefaultMutableJweHeader(Jwts.header())
def provider = Providers.findBouncyCastle(Conditions.TRUE) def provider = TestKeys.BC
def request = new DefaultKeyRequest(encKey, provider, null, header, Jwts.ENC.A128GCM) def request = new DefaultKeyRequest(encKey, provider, null, header, Jwts.ENC.A128GCM)
def result = alg.getEncryptionKey(request) def result = alg.getEncryptionKey(request)
assertNotNull result.getKey() assertNotNull result.getKey()
@ -59,7 +58,7 @@ class EcdhKeyAlgorithmTest {
PublicKey encKey = TestKeys.X25519.pair.public as PublicKey PublicKey encKey = TestKeys.X25519.pair.public as PublicKey
PrivateKey decKey = TestKeys.X25519.pair.private as PrivateKey PrivateKey decKey = TestKeys.X25519.pair.private as PrivateKey
def header = Jwts.header() def header = Jwts.header()
def provider = Providers.findBouncyCastle(Conditions.TRUE) def provider = TestKeys.BC
// encrypt // encrypt
def delegate = new DefaultMutableJweHeader(header) def delegate = new DefaultMutableJweHeader(header)

View File

@ -16,15 +16,11 @@
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Bytes import io.jsonwebtoken.impl.lang.Bytes
import io.jsonwebtoken.impl.lang.Function
import io.jsonwebtoken.impl.lang.Functions
import io.jsonwebtoken.security.InvalidKeyException import io.jsonwebtoken.security.InvalidKeyException
import io.jsonwebtoken.security.UnsupportedKeyException import io.jsonwebtoken.security.UnsupportedKeyException
import org.junit.Test import org.junit.Test
import java.security.spec.AlgorithmParameterSpec import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.ECGenParameterSpec
import java.security.spec.KeySpec
import static org.junit.Assert.* import static org.junit.Assert.*
@ -71,7 +67,7 @@ class EdwardsCurveTest {
@Test @Test
void testFindByKey() { // happy path test void testFindByKey() { // happy path test
for(def alg : EdwardsCurve.VALUES) { for (def alg : EdwardsCurve.VALUES) {
def keyPair = alg.keyPair().build() def keyPair = alg.keyPair().build()
def pub = keyPair.public def pub = keyPair.public
def priv = keyPair.private def priv = keyPair.private
@ -88,7 +84,7 @@ class EdwardsCurveTest {
@Test @Test
void testFindByKeyUsingEncoding() { void testFindByKeyUsingEncoding() {
curves.each { curves.each {
def pair = TestKeys.forCurve(it).pair def pair = TestKeys.forAlgorithm(it).pair
def key = new TestKey(algorithm: 'foo', encoded: pair.public.getEncoded()) def key = new TestKey(algorithm: 'foo', encoded: pair.public.getEncoded())
def found = EdwardsCurve.findByKey(key) def found = EdwardsCurve.findByKey(key)
assertEquals(it, found) assertEquals(it, found)
@ -107,7 +103,7 @@ class EdwardsCurveTest {
@Test @Test
void testFindByKeyUsingMalformedEncoding() { void testFindByKeyUsingMalformedEncoding() {
curves.each { curves.each {
byte[] encoded = EdwardsCurve.DER_OID_PREFIX // just the prefix isn't enough byte[] encoded = EdwardsCurve.ASN1_OID_PREFIX // just the prefix isn't enough
def key = new TestKey(algorithm: 'foo', encoded: encoded) def key = new TestKey(algorithm: 'foo', encoded: encoded)
assertNull EdwardsCurve.findByKey(key) assertNull EdwardsCurve.findByKey(key)
} }
@ -116,10 +112,10 @@ class EdwardsCurveTest {
@Test @Test
void testToPrivateKey() { void testToPrivateKey() {
curves.each { curves.each {
def pair = TestKeys.forCurve(it).pair def pair = TestKeys.forAlgorithm(it).pair
def key = pair.getPrivate() def key = pair.getPrivate()
def d = it.getKeyMaterial(key) def d = it.getKeyMaterial(key)
def result = it.toPrivateKey(d, it.getProvider()) def result = it.toPrivateKey(d, null)
assertEquals(key, result) assertEquals(key, result)
} }
} }
@ -127,10 +123,11 @@ class EdwardsCurveTest {
@Test @Test
void testToPublicKey() { void testToPublicKey() {
curves.each { curves.each {
def pair = TestKeys.forCurve(it).pair def bundle = TestKeys.forAlgorithm(it)
def pair = bundle.pair
def key = pair.getPublic() def key = pair.getPublic()
def x = it.getKeyMaterial(key) def x = it.getKeyMaterial(key)
def result = it.toPublicKey(x, it.getProvider()) def result = it.toPublicKey(x, null)
assertEquals(key, result) assertEquals(key, result)
} }
} }
@ -141,7 +138,7 @@ class EdwardsCurveTest {
byte[] d = new byte[it.encodedKeyByteLength + 1] // more than required byte[] d = new byte[it.encodedKeyByteLength + 1] // more than required
Randoms.secureRandom().nextBytes(d) Randoms.secureRandom().nextBytes(d)
try { try {
it.toPrivateKey(d, it.getProvider()) it.toPrivateKey(d, null)
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.id} encoded PrivateKey length. Should be " + String msg = "Invalid ${it.id} encoded PrivateKey length. Should be " +
"${Bytes.bitsMsg(it.keyBitLength)}, found ${Bytes.bytesMsg(d.length)}." "${Bytes.bitsMsg(it.keyBitLength)}, found ${Bytes.bytesMsg(d.length)}."
@ -150,13 +147,24 @@ class EdwardsCurveTest {
} }
} }
@Test
void testPrivateKeySpecJdk11() {
curves.each {
byte[] d = new byte[it.encodedKeyByteLength]; Randoms.secureRandom().nextBytes(d)
def keySpec = it.privateKeySpec(d, false) // standard = false for JDK 11 bug
assertTrue keySpec instanceof PKCS8EncodedKeySpec
def expectedEncoded = Bytes.concat(it.PRIVATE_KEY_JDK11_PREFIX, d)
assertArrayEquals expectedEncoded, ((PKCS8EncodedKeySpec)keySpec).getEncoded()
}
}
@Test @Test
void testToPublicKeyInvalidLength() { void testToPublicKeyInvalidLength() {
curves.each { curves.each {
byte[] x = new byte[it.encodedKeyByteLength - 1] // less than required byte[] x = new byte[it.encodedKeyByteLength - 1] // less than required
Randoms.secureRandom().nextBytes(x) Randoms.secureRandom().nextBytes(x)
try { try {
it.toPublicKey(x, it.getProvider()) it.toPublicKey(x, null)
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.id} encoded PublicKey length. Should be " + String msg = "Invalid ${it.id} encoded PublicKey length. Should be " +
"${Bytes.bitsMsg(it.keyBitLength)}, found ${Bytes.bytesMsg(x.length)}." "${Bytes.bitsMsg(it.keyBitLength)}, found ${Bytes.bytesMsg(x.length)}."
@ -179,7 +187,7 @@ class EdwardsCurveTest {
byte[] encoded = Bytes.concat( byte[] encoded = Bytes.concat(
[0x30, it.encodedKeyByteLength + 10 + DER_NULL.length, 0x30, 0x05] as byte[], [0x30, it.encodedKeyByteLength + 10 + DER_NULL.length, 0x30, 0x05] as byte[],
it.DER_OID, it.ASN1_OID,
DER_NULL, // this should be skipped when getting key material DER_NULL, // this should be skipped when getting key material
[0x03, it.encodedKeyByteLength + 1, 0x00] as byte[], [0x03, it.encodedKeyByteLength + 1, 0x00] as byte[],
x x
@ -217,7 +225,7 @@ class EdwardsCurveTest {
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: Missing or incorrect algorithm OID." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: Missing or incorrect algorithm OID." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
@ -231,13 +239,13 @@ class EdwardsCurveTest {
encoded[0] = 0x20 // anything other than 0x03, 0x04, 0x05 encoded[0] = 0x20 // anything other than 0x03, 0x04, 0x05
curves.each { curves.each {
// prefix it with the OID to make it look valid: // prefix it with the OID to make it look valid:
encoded = Bytes.concat(it.DER_OID, encoded) encoded = Bytes.concat(it.ASN1_OID, encoded)
def key = new TestKey(encoded: encoded) def key = new TestKey(encoded: encoded)
try { try {
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
@ -251,13 +259,13 @@ class EdwardsCurveTest {
size = it.encodedKeyByteLength size = it.encodedKeyByteLength
byte[] keyBytes = new byte[size] byte[] keyBytes = new byte[size]
Randoms.secureRandom().nextBytes(keyBytes) Randoms.secureRandom().nextBytes(keyBytes)
byte[] encoded = Bytes.concat(it.PUBLIC_KEY_DER_PREFIX, keyBytes) byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes)
encoded[11] = 0x01 // should always be zero encoded[11] = 0x01 // should always be zero
def key = new TestKey(encoded: encoded) def key = new TestKey(encoded: encoded)
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: BIT STREAM should not indicate unused bytes." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: BIT STREAM should not indicate unused bytes." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
@ -271,13 +279,13 @@ class EdwardsCurveTest {
size = it.encodedKeyByteLength size = it.encodedKeyByteLength
byte[] keyBytes = new byte[size] byte[] keyBytes = new byte[size]
Randoms.secureRandom().nextBytes(keyBytes) Randoms.secureRandom().nextBytes(keyBytes)
byte[] encoded = Bytes.concat(it.PRIVATE_KEY_DER_PREFIX, keyBytes) byte[] encoded = Bytes.concat(it.PRIVATE_KEY_ASN1_PREFIX, keyBytes)
encoded[14] = 0x0F // should always be 0x04 (DER SEQUENCE tag) encoded[14] = 0x0F // should always be 0x04 (ASN.1 SEQUENCE tag)
def key = new TestKey(encoded: encoded) def key = new TestKey(encoded: encoded)
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
@ -291,13 +299,13 @@ class EdwardsCurveTest {
size = it.encodedKeyByteLength - 1 // one less than required size = it.encodedKeyByteLength - 1 // one less than required
byte[] keyBytes = new byte[size] byte[] keyBytes = new byte[size]
Randoms.secureRandom().nextBytes(keyBytes) Randoms.secureRandom().nextBytes(keyBytes)
byte[] encoded = Bytes.concat(it.PUBLIC_KEY_DER_PREFIX, keyBytes) byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes)
encoded[10] = (byte) (size + 1) // DER size value (zero byte + key bytes) encoded[10] = (byte) (size + 1) // ASN.1 size value (zero byte + key bytes)
def key = new TestKey(encoded: encoded) def key = new TestKey(encoded: encoded)
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
@ -311,67 +319,21 @@ class EdwardsCurveTest {
size = it.encodedKeyByteLength + 1 // one less than required size = it.encodedKeyByteLength + 1 // one less than required
byte[] keyBytes = new byte[size] byte[] keyBytes = new byte[size]
Randoms.secureRandom().nextBytes(keyBytes) Randoms.secureRandom().nextBytes(keyBytes)
byte[] encoded = Bytes.concat(it.PUBLIC_KEY_DER_PREFIX, keyBytes) byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes)
encoded[10] = (byte) (size + 1) // DER size value (zero byte + key bytes) encoded[10] = (byte) (size + 1) // ASN.1 size value (zero byte + key bytes)
def key = new TestKey(encoded: encoded) def key = new TestKey(encoded: encoded)
it.getKeyMaterial(key) it.getKeyMaterial(key)
fail() fail()
} catch (InvalidKeyException ike) { } catch (InvalidKeyException ike) {
String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String
assertEquals msg, ike.getMessage() assertEquals msg, ike.getMessage()
} }
} }
} }
@Test
void testParamKeySpecFactoryWithNullSpec() {
def fn = EdwardsCurve.paramKeySpecFactory(null, true)
assertSame Functions.forNull(), fn
}
@Test
void testXecParamKeySpecFactory() {
AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test
def fn = EdwardsCurve.paramKeySpecFactory(spec, false) as EdwardsCurve.ParameterizedKeySpecFactory
assertSame spec, fn.params
assertSame EdwardsCurve.XEC_PRIV_KEY_SPEC_CTOR, fn.keySpecFactory
}
@Test
void testEdEcParamKeySpecFactory() {
AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test
def fn = EdwardsCurve.paramKeySpecFactory(spec, true) as EdwardsCurve.ParameterizedKeySpecFactory
assertSame spec, fn.params
assertSame EdwardsCurve.EDEC_PRIV_KEY_SPEC_CTOR, fn.keySpecFactory
}
@Test
void testParamKeySpecFactoryInvocation() {
AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test
KeySpec keySpec = new PasswordSpec("foo".toCharArray()) // any KeySpec impl will do
byte[] d = new byte[32]
Randoms.secureRandom().nextBytes(d)
def keySpecFn = new Function<Object, KeySpec>() {
@Override
KeySpec apply(Object o) {
assertTrue o instanceof Object[]
Object[] args = (Object[]) o
assertSame spec, args[0]
assertSame d, args[1]
return keySpec // simulate a creation
}
}
def fn = new EdwardsCurve.ParameterizedKeySpecFactory(spec, keySpecFn)
def result = fn.apply(d)
assertSame keySpec, result
}
@Test @Test
void testDerivePublicKeyFromPrivateKey() { void testDerivePublicKeyFromPrivateKey() {
for(def curve : EdwardsCurve.VALUES) { for (def curve : EdwardsCurve.VALUES) {
def pair = curve.keyPair().build() // generate a standard key pair using the JCA APIs def pair = curve.keyPair().build() // generate a standard key pair using the JCA APIs
def pubKey = pair.getPublic() def pubKey = pair.getPublic()
def derivedPubKey = EdwardsCurve.derivePublic(pair.getPrivate()) def derivedPubKey = EdwardsCurve.derivePublic(pair.getPrivate())

View File

@ -15,7 +15,9 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Bytes
import io.jsonwebtoken.impl.lang.CheckedFunction import io.jsonwebtoken.impl.lang.CheckedFunction
import io.jsonwebtoken.lang.Classes
import io.jsonwebtoken.security.SecurityException import io.jsonwebtoken.security.SecurityException
import io.jsonwebtoken.security.SignatureException import io.jsonwebtoken.security.SignatureException
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
@ -23,9 +25,14 @@ import org.junit.Test
import javax.crypto.Cipher import javax.crypto.Cipher
import javax.crypto.Mac import javax.crypto.Mac
import java.security.Provider import java.security.*
import java.security.Security import java.security.cert.CertificateException
import java.security.Signature import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.security.spec.InvalidKeySpecException
import java.security.spec.KeySpec
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import static org.junit.Assert.* import static org.junit.Assert.*
@ -37,7 +44,7 @@ class JcaTemplateTest {
@Test @Test
void testGetInstanceExceptionMessage() { void testGetInstanceExceptionMessage() {
def factories = JcaTemplate.FACTORIES def factories = JcaTemplate.FACTORIES
for(def factory : factories) { for (def factory : factories) {
def clazz = factory.getInstanceClass() def clazz = factory.getInstanceClass()
try { try {
factory.get('foo', null) factory.get('foo', null)
@ -45,8 +52,8 @@ class JcaTemplateTest {
if (clazz == Signature || clazz == Mac) { if (clazz == Signature || clazz == Mac) {
assertTrue expected instanceof SignatureException assertTrue expected instanceof SignatureException
} }
String prefix = "Unable to obtain ${clazz.getSimpleName()} instance " + String prefix = "Unable to obtain 'foo' ${clazz.getSimpleName()} instance " +
"from default JCA Provider for JCA algorithm 'foo': " "from default JCA Provider: "
assertTrue expected.getMessage().startsWith(prefix) assertTrue expected.getMessage().startsWith(prefix)
} }
} }
@ -56,7 +63,7 @@ class JcaTemplateTest {
void testGetInstanceWithExplicitProviderExceptionMessage() { void testGetInstanceWithExplicitProviderExceptionMessage() {
def factories = JcaTemplate.FACTORIES def factories = JcaTemplate.FACTORIES
def provider = BC_PROVIDER def provider = BC_PROVIDER
for(def factory : factories) { for (def factory : factories) {
def clazz = factory.getInstanceClass() def clazz = factory.getInstanceClass()
try { try {
factory.get('foo', provider) factory.get('foo', provider)
@ -64,8 +71,8 @@ class JcaTemplateTest {
if (clazz == Signature || clazz == Mac) { if (clazz == Signature || clazz == Mac) {
assertTrue expected instanceof SignatureException assertTrue expected instanceof SignatureException
} }
String prefix = "Unable to obtain ${clazz.getSimpleName()} instance " + String prefix = "Unable to obtain 'foo' ${clazz.getSimpleName()} instance " +
"from specified Provider '${provider.toString()}' for JCA algorithm 'foo': " "from specified '${provider.toString()}' Provider: "
assertTrue expected.getMessage().startsWith(prefix) assertTrue expected.getMessage().startsWith(prefix)
} }
} }
@ -102,69 +109,219 @@ class JcaTemplateTest {
}) })
} }
// @Test @Test
// void testGetInstanceFailureWithExplicitProvider() { void testInstanceFactoryFallbackFailureRetainsOriginalException() {
// //noinspection GroovyUnusedAssignment String alg = 'foo'
// Provider provider = Security.getProvider('SunJCE') NoSuchAlgorithmException ex = new NoSuchAlgorithmException('foo')
// def supplier = new JcaTemplate.JcaInstanceSupplier<Cipher>(Cipher.class, "AES", provider) { def factory = new JcaTemplate.JcaInstanceFactory<Cipher>(Cipher.class) {
// @Override @Override
// protected Cipher doGetInstance() { protected Cipher doGet(String jcaName, Provider provider) throws Exception {
// throw new IllegalStateException("foo") throw ex
// } }
// }
//
// try {
// supplier.getInstance()
// } catch (SecurityException ce) { //should be wrapped as SecurityException
// String msg = ce.getMessage()
// //we check for starts-with/ends-with logic here instead of equals because the JCE provider String value
// //contains the JCE version number, and that can differ across JDK versions. Since we use different JDK
// //versions in the test machine matrix, we don't want test failures from JDKs that run on higher versions
// assertTrue msg.startsWith('Unable to obtain Cipher instance from specified Provider {SunJCE')
// assertTrue msg.endsWith('} for JCA algorithm \'AES\': foo')
// }
// }
//
// @Test
// void testGetInstanceDoesNotWrapCryptoExceptions() {
// def ex = new SecurityException("foo")
// def supplier = new JcaTemplate.JcaInstanceSupplier<Cipher>(Cipher.class, 'AES', null) {
// @Override
// protected Cipher doGetInstance() {
// throw ex
// }
// }
//
// try {
// supplier.getInstance()
// } catch (SecurityException ce) {
// assertSame ex, ce
// }
// }
//
// static void wrapInSignatureException(Class instanceType, String jcaName) {
// def ex = new IllegalArgumentException("foo")
// def supplier = new JcaTemplate.JcaInstanceSupplier<Object>(instanceType, jcaName, null) {
// @Override
// protected Object doGetInstance() {
// throw ex
// }
// }
//
// try {
// supplier.getInstance()
// } catch (SignatureException se) {
// assertSame ex, se.getCause()
// String msg = "Unable to obtain ${instanceType.simpleName} instance from default JCA Provider for JCA algorithm '${jcaName}': foo"
// assertEquals msg, se.getMessage()
// }
// }
// @Test @Override
// void testNonCryptoExceptionForSignatureOrMacInstanceIsWrappedInSignatureException() { protected Provider findBouncyCastle() {
// wrapInSignatureException(Signature.class, 'RSA') return null
// wrapInSignatureException(Mac.class, 'HmacSHA256') }
// } }
try {
factory.get(alg, null)
fail()
} catch (SecurityException se) {
assertSame ex, se.getCause()
String msg = "Unable to obtain '$alg' Cipher instance from default JCA Provider: $alg"
assertEquals msg, se.getMessage()
}
}
@Test
void testWrapWithDefaultJcaProviderAndFallbackProvider() {
JcaTemplate.FACTORIES.each {
Provider fallback = TestKeys.BC
String jcaName = 'foo'
NoSuchAlgorithmException nsa = new NoSuchAlgorithmException("doesn't exist")
Exception out = ((JcaTemplate.JcaInstanceFactory) it).wrap(nsa, jcaName, null, fallback)
assertTrue out instanceof SecurityException
String msg = "Unable to obtain '${jcaName}' ${it.getId()} instance from default JCA Provider or fallback " +
"'${fallback.toString()}' Provider: doesn't exist"
assertEquals msg, out.getMessage()
}
}
@Test
void testFallbackWithBouncyCastle() {
def template = new JcaTemplate('foo', null)
try {
template.generateX509Certificate(Bytes.random(32))
} catch (SecurityException expected) {
String prefix = "Unable to obtain 'foo' CertificateFactory instance from default JCA Provider: "
assertTrue expected.getMessage().startsWith(prefix)
assertTrue expected.getCause() instanceof CertificateException
}
}
@Test
void testFallbackWithoutBouncyCastle() {
def template = new JcaTemplate('foo', null) {
@Override
protected Provider findBouncyCastle() {
return null
}
}
try {
template.generateX509Certificate(Bytes.random(32))
} catch (SecurityException expected) {
String prefix = "Unable to obtain 'foo' CertificateFactory instance from default JCA Provider: "
assertTrue expected.getMessage().startsWith(prefix)
assertTrue expected.getCause() instanceof CertificateException
}
}
static InvalidKeySpecException jdk8213363BugEx(String msg) {
// mock up JDK 11 bug behavior:
String className = 'sun.security.ec.XDHKeyFactory'
String methodName = 'engineGeneratePrivate'
def ste = new StackTraceElement(className, methodName, null, 0)
StackTraceElement[] stes = new StackTraceElement[1]
stes[0] = ste
def cause = new InvalidKeyException(msg)
def ex = new InvalidKeySpecException(cause) {
@Override
StackTraceElement[] getStackTrace() {
return stes
}
}
return ex
}
@Test
void testJdk8213363Bug() {
for (def bundle in [TestKeys.X25519, TestKeys.X448]) {
def privateKey = bundle.pair.private
byte[] d = bundle.alg.getKeyMaterial(privateKey)
byte[] prefix = new byte[2]; prefix[0] = (byte) 0x04; prefix[1] = (byte) d.length
byte[] pkcs8d = Bytes.concat(prefix, d)
int callCount = 0
def ex = jdk8213363BugEx("key length must be ${d.length}")
def template = new Jdk8213363JcaTemplate(bundle.alg.id) {
@Override
protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException {
if (callCount == 0) { // simulate first attempt throwing an exception
callCount++
throw ex
}
// otherwise 2nd call due to fallback logic, simulate a successful call:
return privateKey
}
}
assertSame privateKey, template.generatePrivate(new PKCS8EncodedKeySpec(pkcs8d))
}
}
@Test
void testGeneratePrivateRespecWithoutPkcs8() {
byte[] invalid = Bytes.random(456)
def template = new JcaTemplate('X448', null)
try {
template.generatePrivate(new X509EncodedKeySpec(invalid))
fail()
} catch (SecurityException expected) {
boolean jdk11OrLater = Classes.isAvailable('java.security.interfaces.XECPrivateKey')
String msg = 'KeyFactory callback execution failed: key spec not recognized'
if (jdk11OrLater) {
msg = 'KeyFactory callback execution failed: Only PKCS8EncodedKeySpec and XECPrivateKeySpec supported'
}
assertEquals msg, expected.getMessage()
}
}
@Test
void testGeneratePrivateRespecTooSmall() {
byte[] invalid = Bytes.random(16)
def ex = jdk8213363BugEx("key length must be ${invalid.length}")
def template = new Jdk8213363JcaTemplate('X25519') {
@Override
protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException {
throw ex
}
}
try {
template.generatePrivate(new PKCS8EncodedKeySpec(invalid))
fail()
} catch (SecurityException expected) {
String msg = "KeyFactory callback execution failed: java.security.InvalidKeyException: " +
"key length must be ${invalid.length}"
assertEquals msg, expected.getMessage()
}
}
@Test
void testGeneratePrivateRespecTooLarge() {
byte[] invalid = Bytes.random(50)
def ex = jdk8213363BugEx("key length must be ${invalid.length}")
def template = new Jdk8213363JcaTemplate('X448') {
@Override
protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException {
throw ex
}
}
try {
template.generatePrivate(new PKCS8EncodedKeySpec(invalid))
fail()
} catch (SecurityException expected) {
String msg = "KeyFactory callback execution failed: java.security.InvalidKeyException: " +
"key length must be ${invalid.length}"
assertEquals msg, expected.getMessage()
}
}
@Test
void testGetJdk8213363BugExpectedSizeNoExMsg() {
InvalidKeyException ex = new InvalidKeyException()
def template = new JcaTemplate('X448', null)
assertEquals(-1, template.getJdk8213363BugExpectedSize(ex))
}
@Test
void testGetJdk8213363BugExpectedSizeExMsgDoesntMatch() {
InvalidKeyException ex = new InvalidKeyException('not what is expected')
def template = new JcaTemplate('X448', null)
assertEquals(-1, template.getJdk8213363BugExpectedSize(ex))
}
@Test
void testGetJdk8213363BugExpectedSizeExMsgDoesntContainNumber() {
InvalidKeyException ex = new InvalidKeyException('key length must be foo')
def template = new JcaTemplate('X448', null)
assertEquals(-1, template.getJdk8213363BugExpectedSize(ex))
}
@Test
void testRespecIfNecessaryWithoutPkcs8KeySpec() {
def spec = new X509EncodedKeySpec(Bytes.random(32))
def template = new JcaTemplate('X448', null)
assertNull template.respecIfNecessary(null, spec)
}
@Test
void testRespecIfNecessaryNotJdk8213363Bug() {
def ex = new InvalidKeySpecException('foo')
def template = new JcaTemplate('X448', null)
assertNull template.respecIfNecessary(ex, new PKCS8EncodedKeySpec(Bytes.random(32)))
}
@Test
void testIsJdk11() {
// determine which JDK the test is being run on in CI:
boolean testMachineIsJdk11 = System.getProperty('java.version').startsWith('11')
def template = new JcaTemplate('X448', null)
if (testMachineIsJdk11) {
assertTrue template.isJdk11()
} else {
assertFalse template.isJdk11()
}
}
@Test @Test
void testCallbackThrowsException() { void testCallbackThrowsException() {
@ -183,4 +340,27 @@ class JcaTemplateTest {
} }
} }
@Test
void testWithCertificateFactory() {
def template = new JcaTemplate('X.509', null)
X509Certificate expected = TestKeys.RS256.cert
X509Certificate cert = template.withCertificateFactory(new CheckedFunction<CertificateFactory, X509Certificate>() {
@Override
X509Certificate apply(CertificateFactory certificateFactory) throws Exception {
(X509Certificate)certificateFactory.generateCertificate(new ByteArrayInputStream(expected.getEncoded()))
}
})
assertEquals expected, cert
}
private static class Jdk8213363JcaTemplate extends JcaTemplate {
Jdk8213363JcaTemplate(String jcaName) {
super(jcaName, null)
}
@Override
protected boolean isJdk11() {
return true
}
}
} }

View File

@ -16,7 +16,6 @@
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.Jwts import io.jsonwebtoken.Jwts
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.impl.lang.Converters import io.jsonwebtoken.impl.lang.Converters
import io.jsonwebtoken.io.Decoders import io.jsonwebtoken.io.Decoders
import io.jsonwebtoken.io.Encoders import io.jsonwebtoken.io.Encoders
@ -24,7 +23,10 @@ import io.jsonwebtoken.security.*
import org.junit.Test import org.junit.Test
import javax.crypto.SecretKey import javax.crypto.SecretKey
import java.security.* import java.security.MessageDigest
import java.security.PrivateKey
import java.security.PublicKey
import java.security.SecureRandom
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.security.interfaces.ECKey import java.security.interfaces.ECKey
import java.security.interfaces.ECPublicKey import java.security.interfaces.ECPublicKey
@ -262,21 +264,11 @@ class JwksTest {
PublicKey pub = pair.getPublic() PublicKey pub = pair.getPublic()
PrivateKey priv = pair.getPrivate() PrivateKey priv = pair.getPrivate()
Provider provider = null // assume default
if (pub.getClass().getName().startsWith("org.bouncycastle.")) {
// No native JVM support for the key, so we need to enable BC:
provider = Providers.findBouncyCastle(Conditions.TRUE)
}
// test individual keys // test individual keys
PublicJwk pubJwk = Jwks.builder().provider(provider).key(pub).publicKeyUse("sig").build() PublicJwk pubJwk = Jwks.builder().key(pub).publicKeyUse("sig").build()
assertEquals pub, pubJwk.toKey() assertEquals pub, pubJwk.toKey()
def builder = Jwks.builder().provider(provider).key(priv).publicKeyUse('sig') def builder = Jwks.builder().key(priv).publicKeyUse('sig')
if (alg instanceof EdSignatureAlgorithm) {
// We haven't implemented EdDSA public-key derivation yet, so public key is required
builder.publicKey(pub)
}
PrivateJwk privJwk = builder.build() PrivateJwk privJwk = builder.build()
assertEquals priv, privJwk.toKey() assertEquals priv, privJwk.toKey()
PublicJwk privPubJwk = privJwk.toPublicJwk() PublicJwk privPubJwk = privJwk.toPublicJwk()
@ -287,7 +279,7 @@ class JwksTest {
assertEquals priv, jwkPair.getPrivate() assertEquals priv, jwkPair.getPrivate()
// test pair // test pair
builder = Jwks.builder().provider(provider) builder = Jwks.builder()
if (pub instanceof ECKey) { if (pub instanceof ECKey) {
builder = builder.ecKeyPair(pair) builder = builder.ecKeyPair(pair)
} else if (pub instanceof RSAKey) { } else if (pub instanceof RSAKey) {

View File

@ -21,7 +21,6 @@ import io.jsonwebtoken.security.SecurityException
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.security.Provider
import java.security.cert.CertificateEncodingException import java.security.cert.CertificateEncodingException
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
@ -38,6 +37,21 @@ class JwtX509StringConverterTest {
converter = JwtX509StringConverter.INSTANCE converter = JwtX509StringConverter.INSTANCE
} }
/**
* Ensures we can convert and convert-back all OpenSSL certs across all JVM versions automatically
* (because X25519 and X448 >= JDK 11 and Ed25519 and Ed448 are >= JDK 15), but they should still work on earlier
* JDKs due to JcaTemplate auto-fallback with BouncyCastle
*/
@Test
void testOpenSSLCertRoundtrip() {
// X25519 and X448 don't have certs, so we filter to leave those out:
TestKeys.ASYM.findAll({ it.cert != null }).each {
X509Certificate cert = it.cert
String encoded = converter.applyTo(cert)
assertEquals cert, converter.applyFrom(encoded)
}
}
@Test @Test
void testApplyToThrowsEncodingException() { void testApplyToThrowsEncodingException() {
@ -81,53 +95,24 @@ class JwtX509StringConverterTest {
@Test @Test
void testApplyFromBadBase64() { void testApplyFromBadBase64() {
final CertificateException ex = new CertificateException('nope') String s = 'f$oo'
converter = new JwtX509StringConverter() {
@Override
protected X509Certificate toCert(byte[] der, Provider provider) throws SecurityException {
assertNull provider // ensures not called twice (no fallback) because der bytes aren't available
throw ex
}
}
String s = 'foo'
try { try {
converter.applyFrom(s) converter.applyFrom(s)
fail() fail()
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
String expectedMsg = "Unable to convert Base64 String '$s' to X509Certificate instance. Cause: nope" String expectedMsg = "Unable to convert Base64 String '$s' to X509Certificate instance. " +
"Cause: Illegal base64 character: '\$'"
assertEquals expectedMsg, expected.getMessage() assertEquals expectedMsg, expected.getMessage()
assertSame ex, expected.getCause()
} }
} }
@Test @Test
void testApplyFromRsaSsaPssCertStringWithSuccessfulBCRetry() { void testApplyFromInvalidCertString() {
final CertificateException ex = new CertificateException("nope: ${RsaSignatureAlgorithm.PSS_OID}")
converter = new JwtX509StringConverter() {
@Override
protected X509Certificate toCert(byte[] der, Provider provider) throws SecurityException {
if (provider == null) {
throw ex // first time called, throw ex (simulates JVM parse failure)
} else { // this time BC is available:
assertNotNull provider
return super.toCert(der, provider)
}
}
}
def cert = TestKeys.RS256.cert
def validBase64 = Encoders.BASE64.encode(cert.getEncoded())
assertEquals cert, converter.applyFrom(validBase64)
}
@Test
void testApplyFromRsaSsaPssCertStringWithFailedBCRetry() {
final String exMsg = "nope: ${RsaSignatureAlgorithm.PSS_OID}" final String exMsg = "nope: ${RsaSignatureAlgorithm.PSS_OID}"
final CertificateException ex = new CertificateException(exMsg) final CertificateException ex = new CertificateException(exMsg)
converter = new JwtX509StringConverter() { converter = new JwtX509StringConverter() {
@Override @Override
protected X509Certificate toCert(byte[] der, Provider provider) throws SecurityException { protected X509Certificate toCert(byte[] der) throws SecurityException {
throw ex // ensure fails first and second time throw ex // ensure fails first and second time
} }
} }

View File

@ -16,7 +16,6 @@
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.Jwts import io.jsonwebtoken.Jwts
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.impl.lang.Functions import io.jsonwebtoken.impl.lang.Functions
import io.jsonwebtoken.lang.Classes import io.jsonwebtoken.lang.Classes
import io.jsonwebtoken.security.Jwks import io.jsonwebtoken.security.Jwks
@ -28,7 +27,6 @@ class PrivateConstructorsTest {
void testPrivateCtors() { // for code coverage only void testPrivateCtors() { // for code coverage only
new Classes() new Classes()
new KeysBridge() new KeysBridge()
new Conditions()
new Functions() new Functions()
new Jwts.SIG() new Jwts.SIG()
new Jwts.ENC() new Jwts.ENC()

View File

@ -15,7 +15,6 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.lang.Classes import io.jsonwebtoken.lang.Classes
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -75,12 +74,12 @@ class ProvidersTest {
assertTrue bcRegistered() // ensure it exists in the system as expected assertTrue bcRegistered() // ensure it exists in the system as expected
//now ensure that we find it and cache it: //now ensure that we find it and cache it:
def returned = Providers.findBouncyCastle(Conditions.TRUE) def returned = Providers.findBouncyCastle()
assertSame bc, returned assertSame bc, returned
assertSame bc, Providers.BC_PROVIDER.get() // ensure cached for future lookup assertSame bc, Providers.BC_PROVIDER.get() // ensure cached for future lookup
//ensure cache hit works: //ensure cache hit works:
assertSame bc, Providers.findBouncyCastle(Conditions.TRUE) assertSame bc, Providers.findBouncyCastle()
//cleanup() method will remove the provider from the system //cleanup() method will remove the provider from the system
} }
@ -93,12 +92,12 @@ class ProvidersTest {
// ensure we can create one and cache it, *without* modifying the system JVM: // ensure we can create one and cache it, *without* modifying the system JVM:
//now ensure that we find it and cache it: //now ensure that we find it and cache it:
def returned = Providers.findBouncyCastle(Conditions.TRUE) def returned = Providers.findBouncyCastle()
assertNotNull returned assertNotNull returned
assertSame Providers.BC_PROVIDER.get(), returned //ensure cached for future lookup assertSame Providers.BC_PROVIDER.get(), returned //ensure cached for future lookup
assertFalse bcRegistered() //ensure we don't alter the system environment assertFalse bcRegistered() //ensure we don't alter the system environment
assertSame returned, Providers.findBouncyCastle(Conditions.TRUE) //ensure cache hit assertSame returned, Providers.findBouncyCastle() //ensure cache hit
assertFalse bcRegistered() //ensure we don't alter the system environment assertFalse bcRegistered() //ensure we don't alter the system environment
} }
} }

View File

@ -15,7 +15,6 @@
*/ */
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.lang.Classes import io.jsonwebtoken.lang.Classes
import org.junit.After import org.junit.After
import org.junit.Test import org.junit.Test
@ -43,7 +42,7 @@ class ProvidersWithoutBCTest {
mockStatic(Classes) mockStatic(Classes)
expect(Classes.isAvailable(eq("org.bouncycastle.jce.provider.BouncyCastleProvider"))).andReturn(Boolean.FALSE).anyTimes() expect(Classes.isAvailable(eq("org.bouncycastle.jce.provider.BouncyCastleProvider"))).andReturn(Boolean.FALSE).anyTimes()
replay Classes replay Classes
assertNull Providers.findBouncyCastle(Conditions.TRUE) // one should not be created/exist assertNull Providers.findBouncyCastle() // one should not be created/exist
verify Classes verify Classes
assertFalse ProvidersTest.bcRegistered() // nothing should be in the environment assertFalse ProvidersTest.bcRegistered() // nothing should be in the environment
} }

View File

@ -79,8 +79,7 @@ class RsaSignatureAlgorithmTest {
gen.initialize(1024) //too week for any JWA RSA algorithm gen.initialize(1024) //too week for any JWA RSA algorithm
def rsaPair = gen.generateKeyPair() def rsaPair = gen.generateKeyPair()
def provider = RsaSignatureAlgorithm.PS256.getProvider() // in case BC was loaded def pssPair = new JcaTemplate(RsaSignatureAlgorithm.PSS_JCA_NAME, null)
def pssPair = new JcaTemplate(RsaSignatureAlgorithm.PSS_JCA_NAME, provider)
.withKeyPairGenerator(new CheckedFunction<KeyPairGenerator, KeyPair>() { .withKeyPairGenerator(new CheckedFunction<KeyPairGenerator, KeyPair>() {
@Override @Override
KeyPair apply(KeyPairGenerator generator) throws Exception { KeyPair apply(KeyPairGenerator generator) throws Exception {

View File

@ -16,26 +16,21 @@
package io.jsonwebtoken.impl.security package io.jsonwebtoken.impl.security
import io.jsonwebtoken.Identifiable import io.jsonwebtoken.Identifiable
import io.jsonwebtoken.impl.lang.Bytes
import io.jsonwebtoken.impl.lang.CheckedFunction
import io.jsonwebtoken.impl.lang.Conditions
import io.jsonwebtoken.lang.Assert
import io.jsonwebtoken.lang.Classes import io.jsonwebtoken.lang.Classes
import io.jsonwebtoken.lang.Strings import io.jsonwebtoken.lang.Strings
import io.jsonwebtoken.security.SignatureAlgorithm
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.PEMKeyPair import org.bouncycastle.openssl.PEMKeyPair
import org.bouncycastle.openssl.PEMParser import org.bouncycastle.openssl.PEMParser
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.security.KeyFactory
import java.security.PrivateKey import java.security.PrivateKey
import java.security.Provider import java.security.Provider
import java.security.PublicKey import java.security.PublicKey
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.security.spec.KeySpec
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec import java.security.spec.X509EncodedKeySpec
/** /**
@ -55,8 +50,7 @@ import java.security.spec.X509EncodedKeySpec
*/ */
class TestCertificates { class TestCertificates {
private static Provider BC = Assert.notNull(Providers.findBouncyCastle(Conditions.TRUE), static Provider BC = new BouncyCastleProvider()
"BC must be available to test cases.")
private static InputStream getResourceStream(String filename) { private static InputStream getResourceStream(String filename) {
String packageName = TestCertificates.class.getPackage().getName() String packageName = TestCertificates.class.getPackage().getName()
@ -69,116 +63,72 @@ class TestCertificates {
return new PEMParser(new BufferedReader(new InputStreamReader(is, StandardCharsets.ISO_8859_1))) return new PEMParser(new BufferedReader(new InputStreamReader(is, StandardCharsets.ISO_8859_1)))
} }
private static <T> T bcFallback(final Identifiable alg, Closure<T> closure) { private static String keyJcaName(Identifiable alg) {
Provider provider = alg.getProvider() as Provider // null on JVMs with native support for `alg` String jcaName = alg.getId()
try { if (jcaName.startsWith('ES')) {
return closure.call(alg, provider) jcaName = 'EC'
} catch (Throwable t) { } else if (jcaName.startsWith('PS')) {
jcaName = 'RSASSA-PSS'
// All test cert and key files were created with OpenSSL, so the only time this should happen is if the } else if (jcaName.startsWith('RS')) {
// JDK natively supports the alg, but has a bug that prevents it from reading the file correctly. So jcaName = 'RSA'
// we account for those bugs here as indicators that we should retry with BC. }
return jcaName
// https://bugs.openjdk.org/browse/JDK-8242556
// Oracle only backported this fix to JDK 8u271+, 11.0.9+, and 15+, so we'll need to fall back to
// BC (which can read the files correctly) on JDK 9, 10, 12, 13, and 14.
boolean jdk8242556Bug = alg instanceof SignatureAlgorithm && alg.getId().startsWith("PS") &&
t.message.contains('Unsupported algorithm 1.2.840.113549.1.1.10')
// https://bugs.openjdk.org/browse/JDK-8213363) for X25519 and X448 encoded keys. JDK 11's
// SunCE provider incorrectly expects an ASN.1 OCTET STRING (without the DER tag/length prefix)
// when it should actually be a BER-encoded OCTET STRING (with the tag/length prefix).
boolean jdk8213363Bug = alg instanceof EdwardsCurve && !((EdwardsCurve) alg).isSignatureCurve() &&
System.getProperty("java.version").startsWith("11")
// Now assert that we're experiencing one of the expected bugs, because if not, we need to know about
// it in test results and fix this implementation:
if (!jdk8242556Bug && !jdk8213363Bug) {
String msg = "Unable to read ${alg.getId()} file: ${t.message}"
throw new IllegalStateException(msg, t)
} }
// otherwise, we are indeed experiencing one of the expected bugs, so use BC as a backup: private static PublicKey readPublicKey(Identifiable alg) {
return closure.call(alg, BC)
}
}
private static def readPublicKey = { Identifiable alg, Provider provider ->
PEMParser parser = getParser(alg.id + '.pub.pem') PEMParser parser = getParser(alg.id + '.pub.pem')
parser.withCloseable { parser.withCloseable {
SubjectPublicKeyInfo info = parser.readObject() as SubjectPublicKeyInfo SubjectPublicKeyInfo info = it.readObject() as SubjectPublicKeyInfo
JcaTemplate template = new JcaTemplate(alg.getJcaName(), provider) JcaTemplate template = new JcaTemplate(keyJcaName(alg), null)
template.withKeyFactory(new CheckedFunction<KeyFactory, PublicKey>() { return template.generatePublic(new X509EncodedKeySpec(info.getEncoded()))
@Override
PublicKey apply(KeyFactory keyFactory) throws Exception {
return keyFactory.generatePublic(new X509EncodedKeySpec(info.getEncoded()))
}
})
} }
} }
private static def readCert = { Identifiable alg, Provider provider -> private static X509Certificate readCert(Identifiable alg, Provider provider) {
InputStream is = getResourceStream(alg.id + '.crt.pem') InputStream is = getResourceStream(alg.id + '.crt.pem')
is.withCloseable {
JcaTemplate template = new JcaTemplate("X.509", provider) JcaTemplate template = new JcaTemplate("X.509", provider)
template.withCertificateFactory(new CheckedFunction<CertificateFactory, X509Certificate>() { return template.generateX509Certificate(is.getBytes())
@Override
X509Certificate apply(CertificateFactory factory) throws Exception {
return (X509Certificate) factory.generateCertificate(it)
}
})
}
} }
private static def readPrivateKey = { Identifiable alg, Provider provider -> private static PrivateKey readPrivateKey(Identifiable alg) {
final String id = alg.id final String id = alg.id
PEMParser parser = getParser(id + '.key.pem') PEMParser parser = getParser(id + '.key.pem')
parser.withCloseable { parser.withCloseable {
PrivateKeyInfo info PrivateKeyInfo info
Object object = parser.readObject() Object object = it.readObject()
if (object instanceof PEMKeyPair) { if (object instanceof PEMKeyPair) {
info = ((PEMKeyPair) object).getPrivateKeyInfo() info = ((PEMKeyPair) object).getPrivateKeyInfo()
} else { } else {
info = (PrivateKeyInfo) object info = (PrivateKeyInfo) object
} }
def converter = new JcaPEMKeyConverter() final KeySpec spec = new PKCS8EncodedKeySpec(info.getEncoded())
if (provider != null) { return new JcaTemplate(keyJcaName(alg), null).generatePrivate(spec)
converter.setProvider(provider)
} else if (id.startsWith("X") && System.getProperty("java.version").startsWith("11")) {
EdwardsCurve curve = EdwardsCurve.findById(id)
Assert.notNull(curve, "Curve cannot be null.")
int expectedByteLen = ((curve.keyBitLength + 7) / 8) as int
// Address the [JDK 11 SunCE provider bug](https://bugs.openjdk.org/browse/JDK-8213363) for X25519
// and X448 encoded keys: Even though the file is encoded properly (it was created by OpenSSL), JDK 11's
// SunCE provider incorrectly expects an ASN.1 OCTET STRING (without the DER tag/length prefix)
// when it should actually be a BER-encoded OCTET STRING (with the tag/length prefix).
// So we get the raw bytes and use our key generator:
byte[] keyOctets = info.getPrivateKey().getOctets()
int lenDifference = Bytes.length(keyOctets) - expectedByteLen
if (lenDifference > 0) {
byte[] derPrefixRemoved = new byte[expectedByteLen]
System.arraycopy(keyOctets, lenDifference, derPrefixRemoved, 0, expectedByteLen)
keyOctets = derPrefixRemoved
} }
return curve.toPrivateKey(keyOctets, null)
}
return converter.getPrivateKey(info)
}
}
static TestKeys.Bundle readBundle(EdwardsCurve curve) {
//PublicKey pub = readTestPublicKey(curve)
//PrivateKey priv = readTestPrivateKey(curve)
PublicKey pub = bcFallback(curve, readPublicKey) as PublicKey
PrivateKey priv = bcFallback(curve, readPrivateKey) as PrivateKey
return new TestKeys.Bundle(pub, priv)
} }
static TestKeys.Bundle readBundle(Identifiable alg) { static TestKeys.Bundle readBundle(Identifiable alg) {
//X509Certificate cert = readTestCertificate(alg)
//PrivateKey priv = readTestPrivateKey(alg) PublicKey pub = readPublicKey(alg) as PublicKey
X509Certificate cert = bcFallback(alg, readCert) as X509Certificate PrivateKey priv = readPrivateKey(alg) as PrivateKey
PrivateKey priv = bcFallback(alg, readPrivateKey) as PrivateKey
return new TestKeys.Bundle(cert, priv) // X25519 and X448 cannot have self-signed certs:
if (alg instanceof EdwardsCurve && !((EdwardsCurve) alg).isSignatureCurve()) {
return new TestKeys.Bundle(alg, pub, priv)
}
// otherwise we can get a cert:
// If the public key loaded is a BC key, the default provider doesn't understand the cert key OID
// (for example, an Ed25519 key on JDK 8 which doesn't natively support such keys). This means the
// X.509 certificate should also be loaded by BC; otherwise the Sun X.509 CertificateFactory returns
// a certificate with certificate.getPublicKey() being a sun X509Key instead of the type-specific key we want:
Provider provider = null
if (pub.getClass().getName().startsWith("org.bouncycastle")) {
provider = BC
}
X509Certificate cert = readCert(alg, provider) as X509Certificate
PublicKey certPub = cert.getPublicKey()
assert pub.equals(certPub)
return new TestKeys.Bundle(alg, pub, priv, cert)
} }
} }

View File

@ -23,6 +23,7 @@ import io.jsonwebtoken.security.Jwks
import javax.crypto.SecretKey import javax.crypto.SecretKey
import java.security.KeyPair import java.security.KeyPair
import java.security.PrivateKey import java.security.PrivateKey
import java.security.Provider
import java.security.PublicKey import java.security.PublicKey
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
@ -31,6 +32,8 @@ import java.security.cert.X509Certificate
*/ */
class TestKeys { class TestKeys {
static Provider BC = TestCertificates.BC
// ======================================================= // =======================================================
// Secret Keys // Secret Keys
// ======================================================= // =======================================================
@ -100,29 +103,18 @@ class TestKeys {
return TestKeys.metaClass.getAttribute(TestKeys, id) as Bundle return TestKeys.metaClass.getAttribute(TestKeys, id) as Bundle
} }
static Bundle forCurve(EdwardsCurve curve) {
return TestKeys.metaClass.getAttribute(TestKeys, curve.getId()) as Bundle
}
static class Bundle { static class Bundle {
Identifiable alg
X509Certificate cert X509Certificate cert
List<X509Certificate> chain List<X509Certificate> chain
KeyPair pair KeyPair pair
Bundle(X509Certificate cert, PrivateKey privateKey) { Bundle(Identifiable alg, PublicKey publicKey, PrivateKey privateKey, X509Certificate cert = null) {
this.alg = alg
this.cert = cert this.cert = cert
this.chain = Collections.of(cert) this.chain = cert != null ? Collections.of(cert) : Collections.<X509Certificate> emptyList()
this.pair = new KeyPair(cert.getPublicKey(), privateKey) this.pair = new KeyPair(publicKey, privateKey);
}
Bundle(KeyPair pair) {
this.cert = null
this.chain = Collections.emptyList()
this.pair = pair
}
Bundle(PublicKey pub, PrivateKey priv) {
this(new KeyPair(pub, priv))
} }
} }
} }