Add NimbusJwtEncoder Builders

Closes gh-16267

Signed-off-by: Suraj Bhadrike <surajbh2233@gmail.com>
This commit is contained in:
Suraj Bhadrike 2025-05-09 01:43:55 +05:30 committed by Josh Cummings
parent 709f5db0e5
commit ee09215f89
2 changed files with 454 additions and 3 deletions

View File

@ -18,6 +18,8 @@ package org.springframework.security.oauth2.jwt;
import java.net.URI;
import java.net.URL;
import java.security.KeyPair;
import java.security.interfaces.ECPublicKey;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
@ -25,19 +27,28 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
@ -47,6 +58,7 @@ import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -83,6 +95,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
private JwsHeader jwsHeader;
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
private final JWKSource<SecurityContext> jwkSource;
@ -119,14 +133,16 @@ public final class NimbusJwtEncoder implements JwtEncoder {
this.jwkSelector = jwkSelector;
}
public void setJwsHeader(JwsHeader jwsHeader) {
this.jwsHeader = jwsHeader;
}
@Override
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
Assert.notNull(parameters, "parameters cannot be null");
JwsHeader headers = parameters.getJwsHeader();
if (headers == null) {
headers = DEFAULT_JWS_HEADER;
}
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
JwtClaimsSet claims = parameters.getClaims();
JWK jwk = selectJwk(headers);
@ -369,4 +385,249 @@ public final class NimbusJwtEncoder implements JwtEncoder {
}
}
/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* {@link SecretKey}.
* @param secretKey the {@link SecretKey} to use for signing JWTs
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
* @since 7.0
*/
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
Assert.notNull(secretKey, "secretKey cannot be null");
return new SecretKeyJwtEncoderBuilder(secretKey);
}
/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
* {@link ECKey}.
* @param keyPair the {@link KeyPair} to use for signing JWTs
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
* @since 7.0
*/
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
"keyPair, its private key, and public key must not be null");
Assert.isTrue(
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
"keyPair must be an RSAKey or an ECKey");
if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
return new RsaKeyPairJwtEncoderBuilder(keyPair);
}
if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
return new EcKeyPairJwtEncoderBuilder(keyPair);
}
throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey");
}
/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link SecretKey}.
*
* @since 7.0
*/
public static final class SecretKeyJwtEncoderBuilder {
private final SecretKey secretKey;
private String keyId;
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
this.secretKey = secretKey;
}
/**
* Sets the JWS algorithm to use for signing. Defaults to
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
* HS512).
* @param macAlgorithm the {@link MacAlgorithm} to use
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
return this;
}
/**
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
* header.
* @param keyId the key identifier
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
this.keyId = keyId;
return this;
}
/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
* @throws IllegalStateException if the configured JWS algorithm is not compatible
* with a {@link SecretKey}.
*/
public NimbusJwtEncoder build() {
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
.algorithm(this.jwsAlgorithm)
.keyID(this.keyId);
OctetSequenceKey jwk = builder.build();
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
return encoder;
}
}
/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public abstract static class KeyPairJwtEncoderBuilder {
private final KeyPair keyPair;
private String keyId;
private JWSAlgorithm jwsAlgorithm;
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
this.keyPair = keyPair;
}
/**
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
* type (e.g., RS256 for RSA, ES256 for EC).
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
* @return this builder instance for method chaining
*/
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
return this;
}
/**
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
* header.
* @param keyId the key identifier
* @return this builder instance for method chaining
*/
public KeyPairJwtEncoderBuilder keyId(String keyId) {
this.keyId = keyId;
return this;
}
/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
* @throws IllegalStateException if the key type is unsupported or the configured
* JWS algorithm is not compatible with the key type.
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
* curve)
*/
public NimbusJwtEncoder build() {
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
JWK jwk = buildJwk();
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
.keyId(jwk.getKeyID())
.build();
encoder.setJwsHeader(jwsHeader);
return encoder;
}
protected abstract JWK buildJwk();
}
/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) {
super(keyPair);
}
@Override
protected JWK buildJwk() {
if (super.jwsAlgorithm == null) {
super.jwsAlgorithm = JWSAlgorithm.RS256;
}
Assert.state(JWSAlgorithm.Family.RSA.contains(super.jwsAlgorithm),
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an RSAKey. "
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
RSAKey.Builder builder = new RSAKey.Builder(
(java.security.interfaces.RSAPublicKey) super.keyPair.getPublic())
.privateKey(super.keyPair.getPrivate())
.keyID(super.keyId)
.keyUse(KeyUse.SIGNATURE)
.algorithm(super.jwsAlgorithm);
return builder.build();
}
}
/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) {
super(keyPair);
}
@Override
protected JWK buildJwk() {
if (super.jwsAlgorithm == null) {
super.jwsAlgorithm = JWSAlgorithm.ES256;
}
Assert.state(JWSAlgorithm.Family.EC.contains(super.jwsAlgorithm),
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an ECKey. "
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic();
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
if (curve == null) {
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
}
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
.privateKey(super.keyPair.getPrivate())
.keyUse(KeyUse.SIGNATURE)
.keyID(super.keyId)
.algorithm(super.jwsAlgorithm);
try {
return builder.build();
}
catch (IllegalStateException ex) {
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
}
}
}
}

View File

@ -16,12 +16,20 @@
package org.springframework.security.oauth2.jwt;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
@ -48,6 +56,7 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
@ -344,6 +353,187 @@ public class NimbusJwtEncoderTests {
verifyNoInteractions(selector);
}
@Test
void secretKeyBuilderWithDefaultAlgorithm() {
SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
assertThatNoException().isThrownBy(jwt::getClaims);
assertJwt(jwt);
}
@Test
void secretKeyBuilderWithKeyId() {
SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
String keyId = "test-key-id";
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("kid").toString()).isEqualTo(keyId);
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
assertThatNoException().isThrownBy(jwt::getClaims);
assertJwt(jwt);
}
@Test
void secretKeyBuilderWithCustomJwkSelector() {
SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
String keyId = "test-key-id";
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId);
assertThat(jwt.getClaims()).containsEntry("sub", "subject");
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void secretKeyBuilderWithCustomHeaders() {
SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
JwtClaimsSet claims = buildClaims();
JwsHeader headers = JwsHeader.with(org.springframework.security.oauth2.jose.jws.MacAlgorithm.HS256)
.type("JWT")
.contentType("application/jwt")
.build();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(headers, claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("typ").toString()).isEqualTo("JWT");
assertThat(jwt.getHeaders().get("cty").toString()).isEqualTo("application/jwt");
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void keyPairBuilderWithRsaDefaultAlgorithm() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256");
assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
assertThat(jwt.getAudience()).isEqualTo(claims.getAudience());
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void keyPairBuilderWithRsaCustomAlgorithm() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair)
.signatureAlgorithm(SignatureAlgorithm.RS512)
.build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS512");
assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void keyPairBuilderWithEcDefaultAlgorithm() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC");
keyPairGenerator.initialize(256);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256");
assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void keyPairBuilderWithEcCustomAlgorithm() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC");
keyPairGenerator.initialize(256);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair)
.keyId(UUID.randomUUID().toString())
.signatureAlgorithm(SignatureAlgorithm.ES256)
.build();
JwtClaimsSet claims = buildClaims();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256");
assertThatNoException().isThrownBy(() -> jwt.getClaims());
assertJwt(jwt);
}
@Test
void keyPairBuilderWithKeyId() throws Exception { // d
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
String keyId = "test-key-id";
JwtClaimsSet claims = buildClaims();
NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).keyId(keyId).build();
Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
assertThat(jwt).isNotNull();
assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId);
assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256");
assertThatNoException().isThrownBy(() -> jwt.getClaims());
}
private JwtClaimsSet buildClaims() {
Instant now = Instant.now();
return JwtClaimsSet.builder()
.issuer("https://example.com")
.subject("subject")
.audience(Collections.singletonList("audience"))
.issuedAt(now)
.notBefore(now)
.expiresAt(now.plus(1, ChronoUnit.HOURS))
.id(UUID.randomUUID().toString())
.claim("custom", "value")
.build();
}
private static void assertJwt(Jwt jwt) {
assertThat(jwt.getIssuer().toString()).isEqualTo("https://example.com");
assertThat(jwt.getSubject()).isEqualTo("subject");
assertThat(jwt.getAudience()).containsExactly("audience");
assertThat(jwt.getIssuedAt()).isNotNull();
assertThat(jwt.getNotBefore()).isNotNull();
assertThat(jwt.getExpiresAt()).isNotNull();
assertThat(jwt.getId()).isNotNull();
assertThat(jwt.getClaim("custom").toString()).isEqualTo("value");
}
private static final class JwkListResultCaptor implements Answer<List<JWK>> {
private List<JWK> result;