Polish setJwkSelector

Make so that it runs only when selection is needed.
Require the provided selector be non-null.
Add Tests.

Issue gh-16170
This commit is contained in:
Josh Cummings 2025-02-14 09:35:21 -07:00
parent e22bc11cc9
commit 6793334575
3 changed files with 83 additions and 21 deletions

View File

@ -87,17 +87,12 @@ public final class NimbusJwtEncoder implements JwtEncoder {
private final JWKSource<SecurityContext> jwkSource;
private Converter<List<JWK>, JWK> jwkSelector= (jwks)->{
if (jwks.size() > 1) {
throw new JwtEncodingException(String.format(
"Failed to select a key since there are multiple for the signing algorithm [%s]; " +
"please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm()));
}
if (jwks.isEmpty()) {
private Converter<List<JWK>, JWK> jwkSelector = (jwks) -> {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
}
return jwks.get(0);
String.format(
"Failed to select a key since there are multiple for the signing algorithm [%s]; "
+ "please specify a selector in NimbusJwsEncoder#setJwkSelector",
jwks.get(0).getAlgorithm()));
};
/**
@ -108,16 +103,19 @@ public final class NimbusJwtEncoder implements JwtEncoder {
Assert.notNull(jwkSource, "jwkSource cannot be null");
this.jwkSource = jwkSource;
}
/**
* Use this strategy to reduce the list of matching JWKs down to a since one.
* <p> For example, you can call {@code setJwkSelector(List::getFirst)} in order
* to have this encoder select the first match.
* Use this strategy to reduce the list of matching JWKs when there is more than one.
* <p>
* For example, you can call {@code setJwkSelector(List::getFirst)} in order to have
* this encoder select the first match.
*
* <p> By default, the class with throw an exception if there is more than one result.
* <p>
* By default, the class with throw an exception.
* @since 6.5
*/
public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
if(null!=jwkSelector)
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
this.jwkSelector = jwkSelector;
}
@ -149,6 +147,13 @@ public final class NimbusJwtEncoder implements JwtEncoder {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
}
if (jwks.isEmpty()) {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
}
if (jwks.size() == 1) {
return jwks.get(0);
}
return this.jwkSelector.convert(jwks);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -59,6 +59,10 @@ public final class TestJwks {
private TestJwks() {
}
public static RSAKey.Builder rsa() {
return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY);
}
public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
// @formatter:off
return new RSAKey.Builder(publicKey)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
@ -39,6 +40,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@ -51,6 +53,8 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
/**
* Tests for {@link NimbusJwtEncoder}.
@ -109,7 +113,7 @@ public class NimbusJwtEncoderTests {
@Test
public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
this.jwkList.add(rsaJwk);
this.jwkList.add(rsaJwk);
@ -118,7 +122,7 @@ public class NimbusJwtEncoderTests {
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
.withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]");
}
@Test
@ -291,6 +295,55 @@ public class NimbusJwtEncoderTests {
assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID());
}
@Test
public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception {
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk));
Converter<List<JWK>, JWK> selector = mock(Converter.class);
given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK);
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
jwtEncoder.setJwkSelector(selector);
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
jwtEncoder.encode(JwtEncoderParameters.from(claims));
verify(selector).convert(any());
}
@Test
public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception {
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
given(jwkSource.get(any(), any())).willReturn(List.of(jwk));
Converter<List<JWK>, JWK> selector = mock(Converter.class);
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
jwtEncoder.setJwkSelector(selector);
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
jwtEncoder.encode(JwtEncoderParameters.from(claims));
verifyNoInteractions(selector);
}
@Test
public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception {
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
given(jwkSource.get(any(), any())).willReturn(List.of());
Converter<List<JWK>, JWK> selector = mock(Converter.class);
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
jwtEncoder.setJwkSelector(selector);
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims)));
verifyNoInteractions(selector);
}
private static final class JwkListResultCaptor implements Answer<List<JWK>> {
private List<JWK> result;