From 676b44ebb052d4600bc99e6ecb3fb4f656b85a8a Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Tue, 17 Jun 2025 15:45:41 -0600 Subject: [PATCH] Polish NimbusJwtEncoder Builders - Simplify withKeyPair methods to match withPublicKey convention in NimbusJwtDecoder - Update tests to confirm support of other algorithms - Update constructor to apply additional JWK properties to the default header - Deduce the possibly algorithms for a given key based on curve and key size - Remove algorithm method from EC builder since the algorithm is determined by the Curve of the EC Key Issue gh-16267 Co-Authored-By: Suraj Bhadrike --- .../security/oauth2/jwt/JWKS.java | 87 ++++++ .../security/oauth2/jwt/NimbusJwtEncoder.java | 288 +++++++++--------- .../oauth2/jwt/NimbusJwtEncoderTests.java | 213 ++++++------- 3 files changed, 323 insertions(+), 265 deletions(-) create mode 100644 oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKS.java diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKS.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKS.java new file mode 100644 index 0000000000..8596749bc2 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKS.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Date; +import java.util.Set; + +import javax.crypto.SecretKey; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.crypto.impl.ECDSA; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.KeyOperation; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; + +final class JWKS { + + private JWKS() { + + } + + static OctetSequenceKey.Builder signing(SecretKey key) throws JOSEException { + Date issued = new Date(); + return new OctetSequenceKey.Builder(key).keyOperations(Set.of(KeyOperation.SIGN)) + .keyUse(KeyUse.SIGNATURE) + .algorithm(JWSAlgorithm.HS256) + .keyIDFromThumbprint() + .issueTime(issued) + .notBeforeTime(issued); + } + + static ECKey.Builder signingWithEc(ECPublicKey pub, ECPrivateKey key) throws JOSEException { + Date issued = new Date(); + Curve curve = Curve.forECParameterSpec(pub.getParams()); + JWSAlgorithm algorithm = computeAlgorithm(curve); + return new ECKey.Builder(curve, pub).privateKey(key) + .keyOperations(Set.of(KeyOperation.SIGN)) + .keyUse(KeyUse.SIGNATURE) + .algorithm(algorithm) + .keyIDFromThumbprint() + .issueTime(issued) + .notBeforeTime(issued); + } + + private static JWSAlgorithm computeAlgorithm(Curve curve) { + try { + return ECDSA.resolveAlgorithm(curve); + } + catch (JOSEException ex) { + throw new IllegalArgumentException(ex); + } + } + + static RSAKey.Builder signingWithRsa(RSAPublicKey pub, RSAPrivateKey key) throws JOSEException { + Date issued = new Date(); + return new RSAKey.Builder(pub).privateKey(key) + .keyUse(KeyUse.SIGNATURE) + .keyOperations(Set.of(KeyOperation.SIGN)) + .algorithm(JWSAlgorithm.RS256) + .keyIDFromThumbprint() + .issueTime(issued) + .notBeforeTime(issued); + } + +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index 8e1b7f57e6..8fd1ada518 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -19,7 +19,10 @@ package org.springframework.security.oauth2.jwt; import java.net.URI; import java.net.URL; import java.security.KeyPair; +import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; import java.time.Instant; import java.util.ArrayList; import java.util.Date; @@ -27,8 +30,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import javax.crypto.SecretKey; @@ -37,6 +40,7 @@ import com.nimbusds.jose.JOSEObjectType; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.crypto.MACSigner; import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jose.jwk.ECKey; @@ -58,11 +62,14 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.util.function.ThrowingBiFunction; +import org.springframework.util.function.ThrowingFunction; /** * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the @@ -74,6 +81,8 @@ import org.springframework.util.StringUtils; * NOTE: This implementation uses the Nimbus JOSE + JWT SDK. * * @author Joe Grandja + * @author Josh Cummings + * @author Suraj Bhadrike * @since 5.6 * @see JwtEncoder * @see com.nimbusds.jose.jwk.source.JWKSource @@ -95,7 +104,7 @@ public final class NimbusJwtEncoder implements JwtEncoder { private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory(); - private JwsHeader jwsHeader; + private final JwsHeader defaultJwsHeader; private final Map jwsSigners = new ConcurrentHashMap<>(); @@ -114,10 +123,35 @@ public final class NimbusJwtEncoder implements JwtEncoder { * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} */ public NimbusJwtEncoder(JWKSource jwkSource) { + this.defaultJwsHeader = DEFAULT_JWS_HEADER; Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } + private NimbusJwtEncoder(JWK jwk) { + Assert.notNull(jwk, "jwk cannot be null"); + this.jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); + JwsAlgorithm algorithm = SignatureAlgorithm.from(jwk.getAlgorithm().getName()); + if (algorithm == null) { + algorithm = MacAlgorithm.from(jwk.getAlgorithm().getName()); + } + Assert.notNull(algorithm, "Failed to derive supported algorithm from " + jwk.getAlgorithm()); + JwsHeader.Builder builder = JwsHeader.with(algorithm).type(jwk.getKeyType().getValue()).keyId(jwk.getKeyID()); + URI x509Url = jwk.getX509CertURL(); + if (x509Url != null) { + builder.x509Url(jwk.getX509CertURL().toASCIIString()); + } + List certs = jwk.getX509CertChain(); + if (certs != null) { + builder.x509CertificateChain(certs.stream().map(Base64::toString).toList()); + } + Base64URL thumbprint = jwk.getX509CertSHA256Thumbprint(); + if (thumbprint != null) { + builder.x509SHA256Thumbprint(thumbprint.toString()); + } + this.defaultJwsHeader = builder.build(); + } + /** * Use this strategy to reduce the list of matching JWKs when there is more than one. *

@@ -133,16 +167,15 @@ public final class NimbusJwtEncoder implements JwtEncoder { this.jwkSelector = jwkSelector; } - public void setJwsHeader(JwsHeader jwsHeader) { - this.jwsHeader = jwsHeader; - } - @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); JwsHeader headers = parameters.getJwsHeader(); - headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER; + if (headers == null) { + headers = this.defaultJwsHeader; + } + JwtClaimsSet claims = parameters.getClaims(); JWK jwk = selectJwk(headers); @@ -387,38 +420,34 @@ public final class NimbusJwtEncoder implements JwtEncoder { /** * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided - * {@link SecretKey}. - * @param secretKey the {@link SecretKey} to use for signing JWTs - * @return a {@link SecretKeyJwtEncoderBuilder} for further configuration + * @param publicKey the {@link RSAPublicKey} and @Param privateKey the + * {@link RSAPrivateKey} to use for signing JWTs + * @return a {@link RsaKeyPairJwtEncoderBuilder} * @since 7.0 */ - public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) { - Assert.notNull(secretKey, "secretKey cannot be null"); - return new SecretKeyJwtEncoderBuilder(secretKey); + public static RsaKeyPairJwtEncoderBuilder withKeyPair(RSAPublicKey publicKey, RSAPrivateKey privateKey) { + return new RsaKeyPairJwtEncoderBuilder(publicKey, privateKey); } /** * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided - * {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an - * {@link ECKey}. - * @param keyPair the {@link KeyPair} to use for signing JWTs - * @return a {@link KeyPairJwtEncoderBuilder} for further configuration + * @param publicKey the {@link ECPublicKey} and @param privateKey the + * {@link ECPrivateKey} to use for signing JWTs + * @return a {@link EcKeyPairJwtEncoderBuilder} * @since 7.0 */ - public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) { - Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null, - "keyPair, its private key, and public key must not be null"); - Assert.isTrue( - keyPair.getPrivate() instanceof java.security.interfaces.RSAKey - || keyPair.getPrivate() instanceof java.security.interfaces.ECKey, - "keyPair must be an RSAKey or an ECKey"); - if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) { - return new RsaKeyPairJwtEncoderBuilder(keyPair); - } - if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) { - return new EcKeyPairJwtEncoderBuilder(keyPair); - } - throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey"); + public static EcKeyPairJwtEncoderBuilder withKeyPair(ECPublicKey publicKey, ECPrivateKey privateKey) { + return new EcKeyPairJwtEncoderBuilder(publicKey, privateKey); + } + + /** + * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided + * @param secretKey + * @return a {@link SecretKeyJwtEncoderBuilder} for configuring the {@link JWK} + * @since 7.0 + */ + public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) { + return new SecretKeyJwtEncoderBuilder(secretKey); } /** @@ -429,14 +458,29 @@ public final class NimbusJwtEncoder implements JwtEncoder { */ public static final class SecretKeyJwtEncoderBuilder { - private final SecretKey secretKey; + private static final ThrowingFunction defaultJwk = JWKS::signing; - private String keyId; + private final OctetSequenceKey.Builder builder; - private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + private final Set allowedAlgorithms; private SecretKeyJwtEncoderBuilder(SecretKey secretKey) { - this.secretKey = secretKey; + Assert.notNull(secretKey, "secretKey cannot be null"); + Set allowedAlgorithms = computeAllowedAlgorithms(secretKey); + Assert.notEmpty(allowedAlgorithms, + "This key is too small for any standard JWK symmetric signing algorithm"); + this.allowedAlgorithms = allowedAlgorithms; + this.builder = defaultJwk.apply(secretKey, IllegalArgumentException::new) + .algorithm(this.allowedAlgorithms.iterator().next()); + } + + private Set computeAllowedAlgorithms(SecretKey secretKey) { + try { + return new MACSigner(secretKey).supportedJWSAlgorithms(); + } + catch (JOSEException ex) { + throw new IllegalArgumentException(ex); + } } /** @@ -446,24 +490,24 @@ public final class NimbusJwtEncoder implements JwtEncoder { * @param macAlgorithm the {@link MacAlgorithm} to use * @return this builder instance for method chaining */ - public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) { + public SecretKeyJwtEncoderBuilder algorithm(MacAlgorithm macAlgorithm) { Assert.notNull(macAlgorithm, "macAlgorithm cannot be null"); - Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm), - () -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. " - + "Please use one of the HS256, HS384, or HS512 algorithms."); - - this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName()); + JWSAlgorithm jws = JWSAlgorithm.parse(macAlgorithm.getName()); + Assert.isTrue(this.allowedAlgorithms.contains(jws), String + .format("This key can only support " + "the following algorithms: [%s]", this.allowedAlgorithms)); + this.builder.algorithm(JWSAlgorithm.parse(macAlgorithm.getName())); return this; } /** - * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS - * header. - * @param keyId the key identifier + * Post-process the {@link JWK} using the given {@link Consumer}. For example, you + * may use this to override the default {@code kid} + * @param jwkPostProcessor the post-processor to use * @return this builder instance for method chaining */ - public SecretKeyJwtEncoderBuilder keyId(String keyId) { - this.keyId = keyId; + public SecretKeyJwtEncoderBuilder jwkPostProcessor(Consumer jwkPostProcessor) { + Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null"); + jwkPostProcessor.accept(this.builder); return this; } @@ -474,17 +518,7 @@ public final class NimbusJwtEncoder implements JwtEncoder { * with a {@link SecretKey}. */ public NimbusJwtEncoder build() { - this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256; - - OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE) - .algorithm(this.jwsAlgorithm) - .keyID(this.keyId); - - OctetSequenceKey jwk = builder.build(); - JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource); - encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build()); - return encoder; + return new NimbusJwtEncoder(this.builder.build()); } } @@ -495,137 +529,93 @@ public final class NimbusJwtEncoder implements JwtEncoder { * * @since 7.0 */ - public abstract static class KeyPairJwtEncoderBuilder { + public static final class RsaKeyPairJwtEncoderBuilder { - private final KeyPair keyPair; + private static final ThrowingBiFunction defaultKid = JWKS::signingWithRsa; - private String keyId; + private final RSAKey.Builder builder; - private JWSAlgorithm jwsAlgorithm; - - private KeyPairJwtEncoderBuilder(KeyPair keyPair) { - this.keyPair = keyPair; + private RsaKeyPairJwtEncoderBuilder(RSAPublicKey publicKey, RSAPrivateKey privateKey) { + Assert.notNull(publicKey, "publicKey cannot be null"); + Assert.notNull(privateKey, "privateKey cannot be null"); + this.builder = defaultKid.apply(publicKey, privateKey); } /** - * Sets the JWS algorithm to use for signing. Must be compatible with the key type - * (RSA or EC). If not set, a default algorithm will be chosen based on the key - * type (e.g., RS256 for RSA, ES256 for EC). + * Sets the JWS algorithm to use for signing. Defaults to + * {@link SignatureAlgorithm#RS256}. Must be an RSA-based algorithm * @param signatureAlgorithm the {@link SignatureAlgorithm} to use * @return this builder instance for method chaining */ - public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) { + public RsaKeyPairJwtEncoderBuilder algorithm(SignatureAlgorithm signatureAlgorithm) { Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null"); - this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + this.builder.algorithm(JWSAlgorithm.parse(signatureAlgorithm.getName())); return this; } /** - * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS - * header. - * @param keyId the key identifier + * Add commentMore actions Post-process the {@link JWK} using the given + * {@link Consumer}. For example, you may use this to override the default + * {@code kid} + * @param jwkPostProcessor the post-processor to use * @return this builder instance for method chaining */ - public KeyPairJwtEncoderBuilder keyId(String keyId) { - this.keyId = keyId; + public RsaKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer jwkPostProcessor) { + Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null"); + jwkPostProcessor.accept(this.builder); return this; } /** * Builds the {@link NimbusJwtEncoder} instance. * @return the configured {@link NimbusJwtEncoder} - * @throws IllegalStateException if the key type is unsupported or the configured - * JWS algorithm is not compatible with the key type. - * @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown - * curve) */ public NimbusJwtEncoder build() { - this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString(); - JWK jwk = buildJwk(); - JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource); - JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName())) - .keyId(jwk.getKeyID()) - .build(); - encoder.setJwsHeader(jwsHeader); - return encoder; - } - - protected abstract JWK buildJwk(); - - } - - /** - * A builder for creating {@link NimbusJwtEncoder} instances configured with a - * {@link KeyPair}. - * - * @since 7.0 - */ - public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder { - - private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) { - super(keyPair); - } - - @Override - protected JWK buildJwk() { - if (super.jwsAlgorithm == null) { - super.jwsAlgorithm = JWSAlgorithm.RS256; - } - Assert.state(JWSAlgorithm.Family.RSA.contains(super.jwsAlgorithm), - () -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an RSAKey. " - + "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms."); - - RSAKey.Builder builder = new RSAKey.Builder( - (java.security.interfaces.RSAPublicKey) super.keyPair.getPublic()) - .privateKey(super.keyPair.getPrivate()) - .keyID(super.keyId) - .keyUse(KeyUse.SIGNATURE) - .algorithm(super.jwsAlgorithm); - return builder.build(); + return new NimbusJwtEncoder(this.builder.build()); } } /** * A builder for creating {@link NimbusJwtEncoder} instances configured with a - * {@link KeyPair}. + * {@link ECPublicKey} and {@link ECPrivateKey}. + *

+ * This builder is used to create a {@link NimbusJwtEncoder} * * @since 7.0 */ - public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder { + public static final class EcKeyPairJwtEncoderBuilder { - private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) { - super(keyPair); - } + private static final ThrowingBiFunction defaultKid = JWKS::signingWithEc; - @Override - protected JWK buildJwk() { - if (super.jwsAlgorithm == null) { - super.jwsAlgorithm = JWSAlgorithm.ES256; - } - Assert.state(JWSAlgorithm.Family.EC.contains(super.jwsAlgorithm), - () -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an ECKey. " - + "Please use one of the ES256, ES384, or ES512 algorithms."); + private final ECKey.Builder builder; - ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic(); + private EcKeyPairJwtEncoderBuilder(ECPublicKey publicKey, ECPrivateKey privateKey) { + Assert.notNull(publicKey, "publicKey cannot be null"); + Assert.notNull(privateKey, "privateKey cannot be null"); Curve curve = Curve.forECParameterSpec(publicKey.getParams()); - if (curve == null) { - throw new JwtEncodingException("Unable to determine Curve for EC public key."); - } + Assert.notNull(curve, "Unable to determine Curve for EC public key."); + this.builder = defaultKid.apply(publicKey, privateKey); + } - com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey) - .privateKey(super.keyPair.getPrivate()) - .keyUse(KeyUse.SIGNATURE) - .keyID(super.keyId) - .algorithm(super.jwsAlgorithm); + /** + * Post-process the {@link JWK} using the given {@link Consumer}. For example, you + * may use this to override the default {@code kid} + * @param jwkPostProcessor the post-processor to use + * @return this builder instance for method chaining + */ + public EcKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer jwkPostProcessor) { + Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null"); + jwkPostProcessor.accept(this.builder); + return this; + } - try { - return builder.build(); - } - catch (IllegalStateException ex) { - throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex); - } + /** + * Builds the {@link NimbusJwtEncoder} instance. + * @return the configured {@link NimbusJwtEncoder} + */ + public NimbusJwtEncoder build() { + return new NimbusJwtEncoder(this.builder.build()); } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 2abbaebf33..840b9cdcca 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -16,8 +16,6 @@ package org.springframework.security.oauth2.jwt; -import java.security.KeyPair; -import java.security.KeyPairGenerator; import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPublicKey; import java.time.Instant; @@ -27,12 +25,15 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.UUID; +import java.util.function.Consumer; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; +import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSelector; @@ -40,6 +41,8 @@ import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.OctetSequenceKey; import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.ECKeyGenerator; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.util.Base64URL; @@ -51,12 +54,12 @@ import org.mockito.stubbing.Answer; import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; @@ -353,160 +356,138 @@ public class NimbusJwtEncoderTests { verifyNoInteractions(selector); } + // Default algorithm @Test - void secretKeyBuilderWithDefaultAlgorithm() { - SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); + void keyPairBuilderWithRsaDefaultAlgorithm() throws JOSEException { + RSAKeyGenerator generator = new RSAKeyGenerator(2048); + RSAKey key = generator.generate(); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey()).build(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); - assertThatNoException().isThrownBy(jwt::getClaims); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); } @Test - void secretKeyBuilderWithKeyId() { - SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); - String keyId = "test-key-id"; + void keyPairBuilderWithEcDefaultAlgorithm() throws JOSEException { + ECKeyGenerator generator = new ECKeyGenerator(Curve.P_256); + ECKey key = generator.generate(); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey()).build(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("kid").toString()).isEqualTo(keyId); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); - assertThatNoException().isThrownBy(jwt::getClaims); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); } @Test - void secretKeyBuilderWithCustomJwkSelector() { - SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); - String keyId = "test-key-id"; + void keyPairBuilderWithSecretKeyDefaultAlgorithm() { + SecretKey key = TestKeys.DEFAULT_SECRET_KEY; + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(key).build(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId); - assertThat(jwt.getClaims()).containsEntry("sub", "subject"); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); } + // With custom algorithm @Test - void secretKeyBuilderWithCustomHeaders() { - SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); - JwtClaimsSet claims = buildClaims(); - JwsHeader headers = JwsHeader.with(org.springframework.security.oauth2.jose.jws.MacAlgorithm.HS256) - .type("JWT") - .contentType("application/jwt") + void keyPairBuilderWithRsaWithAlgorithm() throws JOSEException { + RSAKeyGenerator generator = new RSAKeyGenerator(2048); + RSAKey key = generator.generate(); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey()) + .algorithm(SignatureAlgorithm.RS384) .build(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(headers, claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("typ").toString()).isEqualTo("JWT"); - assertThat(jwt.getHeaders().get("cty").toString()).isEqualTo("application/jwt"); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + JwtClaimsSet claims = buildClaims(); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS384); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); } @Test - void keyPairBuilderWithRsaDefaultAlgorithm() throws Exception { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(2048); - KeyPair keyPair = keyPairGenerator.generateKeyPair(); + void keyPairBuilderWithEcWithAlgorithm() throws JOSEException { + ECKeyGenerator generator = new ECKeyGenerator(Curve.P_384); + ECKey key = generator.generate(); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey()).build(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256"); - assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); - assertThat(jwt.getAudience()).isEqualTo(claims.getAudience()); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.ES384); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); } @Test - void keyPairBuilderWithRsaCustomAlgorithm() throws Exception { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(2048); - KeyPair keyPair = keyPairGenerator.generateKeyPair(); + void keyPairBuilderWithSecretKeyWithAlgorithm() { + String keyStr = UUID.randomUUID().toString(); + keyStr += keyStr; + SecretKey Key = new SecretKeySpec(keyStr.getBytes(), "AES"); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(Key).algorithm(MacAlgorithm.HS512).build(); JwtClaimsSet claims = buildClaims(); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); + assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, MacAlgorithm.HS512); + assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID); + } - NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair) - .signatureAlgorithm(SignatureAlgorithm.RS512) + @Test + void keyPairBuilderWhenShortSecretThenHigherAlgorithmNotSupported() { + String keyStr = UUID.randomUUID().toString(); + SecretKey Key = new SecretKeySpec(keyStr.getBytes(), "AES"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> NimbusJwtEncoder.withSecretKey(Key).algorithm(MacAlgorithm.HS512).build()); + } + + @Test + void keyPairBuilderWhenTooShortSecretThenException() { + SecretKey Key = new SecretKeySpec("key".getBytes(), "AES"); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> NimbusJwtEncoder.withSecretKey(Key)); + } + + // with custom jwkPostProcessor + @Test + void keyPairBuilderWithRsaWithAlgorithmAndJwkSource() throws JOSEException { + RSAKeyGenerator generator = new RSAKeyGenerator(2048); + RSAKey key = generator.generate(); + String keyId = UUID.randomUUID().toString(); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey()) + .algorithm(SignatureAlgorithm.RS384) + .jwkPostProcessor((builder) -> builder.keyID(keyId)) .build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS512"); - assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); - assertJwt(jwt); - } - - @Test - void keyPairBuilderWithEcDefaultAlgorithm() throws Exception { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); - keyPairGenerator.initialize(256); - KeyPair keyPair = keyPairGenerator.generateKeyPair(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256"); - assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS384); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId); } @Test - void keyPairBuilderWithEcCustomAlgorithm() throws Exception { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); - keyPairGenerator.initialize(256); - KeyPair keyPair = keyPairGenerator.generateKeyPair(); - NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair) - .keyId(UUID.randomUUID().toString()) - .signatureAlgorithm(SignatureAlgorithm.ES256) + void keyPairBuilderWithEcWithAlgorithmAndJwkSource() throws JOSEException { + ECKeyGenerator generator = new ECKeyGenerator(Curve.P_256); + ECKey key = generator.generate(); + String keyId = UUID.randomUUID().toString(); + Consumer jwkPostProcessor = (builder) -> builder.keyID(keyId); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey()) + .jwkPostProcessor(jwkPostProcessor) .build(); - JwtClaimsSet claims = buildClaims(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256"); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.ES256); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId); } @Test - void keyPairBuilderWithKeyId() throws Exception { // d - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(2048); - KeyPair keyPair = keyPairGenerator.generateKeyPair(); - String keyId = "test-key-id"; + void keyPairBuilderWithSecretKeyWithAlgorithmAndJwkSource() { + final String keyStr = UUID.randomUUID().toString(); + SecretKey key = new SecretKeySpec(keyStr.getBytes(), "HS256"); + String keyId = UUID.randomUUID().toString(); + Consumer jwkPostProcessor = (builder) -> builder.keyID(keyId); + NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(key).jwkPostProcessor(jwkPostProcessor).build(); JwtClaimsSet claims = buildClaims(); - - NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).keyId(keyId).build(); - Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); - - assertThat(jwt).isNotNull(); - assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId); - assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256"); - assertThatNoException().isThrownBy(() -> jwt.getClaims()); + Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims)); + assertJwt(jwt); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, MacAlgorithm.HS256); + assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId); } private JwtClaimsSet buildClaims() {