Add Cache to NimbusJwtDecoderJwkSetUriBuilder

PR gh-8332
This commit is contained in:
Mykyta Bezverkhyi 2020-04-06 10:43:49 +03:00 committed by Josh Cummings
parent b7d3acc02c
commit 9133cc24e4
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 154 additions and 3 deletions

View File

@ -34,6 +34,8 @@ import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
@ -49,6 +51,7 @@ import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import org.springframework.cache.Cache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -68,6 +71,7 @@ import org.springframework.web.client.RestTemplate;
*
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @since 5.2
*/
public final class NimbusJwtDecoder implements JwtDecoder {
@ -215,6 +219,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private String jwkSetUri;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private RestOperations restOperations = new RestTemplate();
private Cache cache;
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
@ -264,6 +269,20 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return this;
}
/**
* Use the given {@link Cache} to store
* <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a>.
*
* @param cache the {@link Cache} to be used to store JWK Set
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
* @since 5.4
*/
public JwkSetUriJwtDecoderBuilder cache(Cache cache) {
Assert.notNull(cache, "cache cannot be null");
this.cache = cache;
return this;
}
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
@ -280,9 +299,17 @@ public final class NimbusJwtDecoder implements JwtDecoder {
}
}
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
}
ResourceRetriever cachingJwkSetRetriever = new CachingResourceRetriever(this.cache, jwkSetRetriever);
return new RemoteJWKSet<>(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NoOpJwkSetCache());
}
JWTProcessor<SecurityContext> processor() {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever);
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
@ -309,6 +336,44 @@ public final class NimbusJwtDecoder implements JwtDecoder {
}
}
private static class NoOpJwkSetCache implements JWKSetCache {
@Override
public void put(JWKSet jwkSet) {
}
@Override
public JWKSet get() {
return null;
}
@Override
public boolean requiresRefresh() {
return true;
}
}
private static class CachingResourceRetriever implements ResourceRetriever {
private final Cache cache;
private final ResourceRetriever resourceRetriever;
CachingResourceRetriever(Cache cache, ResourceRetriever resourceRetriever) {
this.cache = cache;
this.resourceRetriever = resourceRetriever;
}
@Override
public Resource retrieveResource(URL url) throws IOException {
String jwkSet;
try {
jwkSet = cache.get(url.toString(), () -> resourceRetriever.retrieveResource(url).getContent());
} catch (Exception ex) {
throw new IOException(ex);
}
return new Resource(jwkSet, "UTF-8");
}
}
private static class RestOperationsResourceRetriever implements ResourceRetriever {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final RestOperations restOperations;

View File

@ -32,6 +32,7 @@ import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import javax.crypto.SecretKey;
import com.nimbusds.jose.JWSAlgorithm;
@ -55,6 +56,8 @@ import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
@ -66,6 +69,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import static org.assertj.core.api.Assertions.assertThat;
@ -75,6 +79,8 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
@ -85,6 +91,7 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecre
*
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
*/
public class NimbusJwtDecoderTests {
private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}";
@ -247,6 +254,21 @@ public class NimbusJwtDecoderTests {
}
}
@Test
public void shouldThrowJwtExceptionWhenJwkSetEndpointHasNotRespondedAndCacheIsConfigured() throws Exception {
try ( MockWebServer server = new MockWebServer() ) {
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
String jwkSetUri = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).cache(cache).build();
server.shutdown();
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
}
}
@Test
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
Assertions.assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class);
@ -264,6 +286,12 @@ public class NimbusJwtDecoderTests {
Assertions.assertThatCode(() -> builder.restOperations(null)).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void shouldThrowIllegalArgumentExceptionWhenJwkSetCacheIsNull() {
NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI);
Assertions.assertThatCode(() -> builder.cache(null)).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void withPublicKeyWhenNullThenThrowsException() {
assertThatThrownBy(() -> withPublicKey(null))
@ -425,7 +453,7 @@ public class NimbusJwtDecoderTests {
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
JWTProcessor<SecurityContext> processor = withJwkSetUri("https://issuer/.well-known/jwks.json")
JWTProcessor<SecurityContext> processor = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.processor();
NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(processor);
@ -436,6 +464,64 @@ public class NimbusJwtDecoderTests {
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
}
@Test
public void shouldStoreRetrievedJwkSetToCache() {
// given
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.cache(cache)
.build();
// when
jwtDecoder.decode(SIGNED_JWT);
// then
assertThat(cache.get(JWK_SET_URI, String.class)).isEqualTo(JWK_SET);
ArgumentCaptor<RequestEntity> requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
verifyNoMoreInteractions(restOperations);
List<MediaType> acceptHeader = requestEntityCaptor.getValue().getHeaders().getAccept();
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
}
@Test
public void shouldDecodeJwtUsingJwkSetCache() {
// given
RestOperations restOperations = mock(RestOperations.class);
Cache cache = mock(Cache.class);
when(cache.get(eq(JWK_SET_URI), any(Callable.class))).thenReturn(JWK_SET);
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.cache(cache)
.restOperations(restOperations)
.build();
// when
jwtDecoder.decode(SIGNED_JWT);
// then
verify(cache).get(eq(JWK_SET_URI), any(Callable.class));
verifyNoMoreInteractions(cache);
verifyNoInteractions(restOperations);
}
@Test
public void shouldThrowJwtExceptionWhenExceptionOccurredWhileRetrievingJwkSetInsideCachingRetriever() {
// given
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenThrow(new RestClientException("Cannot retrieve JWK Set"));
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.cache(cache)
.build();
// then
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
}
private RSAPublicKey key() throws InvalidKeySpecException {
byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
@ -466,7 +552,7 @@ public class NimbusJwtDecoderTests {
RestOperations restOperations = mock(RestOperations.class);
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.thenReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK));
return withJwkSetUri("https://issuer/.well-known/jwks.json")
return withJwkSetUri(JWK_SET_URI)
.restOperations(restOperations)
.processor();
}