From 11113adf6277c9b851df815055935a0190947cb0 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:32:12 -0700 Subject: [PATCH] Polish Nimbus JWK Source Implementation Issue gh-16251 --- .../security/oauth2/jwt/NimbusJwtDecoder.java | 168 ++++++------------ .../security/oauth2/jwt/JwtDecodersTests.java | 3 +- 2 files changed, 52 insertions(+), 119 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 732ecc2476..df0239ebfe 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -16,15 +16,7 @@ package org.springframework.security.oauth2.jwt; -import com.nimbusds.jose.KeySourceException; -import com.nimbusds.jose.jwk.JWK; -import com.nimbusds.jose.jwk.JWKMatcher; -import com.nimbusds.jose.jwk.JWKSelector; -import com.nimbusds.jose.jwk.source.JWKSetParseException; -import com.nimbusds.jose.jwk.source.JWKSetRetrievalException; -import java.io.IOException; -import java.net.MalformedURLException; -import java.net.URL; +import java.net.URI; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Arrays; @@ -32,7 +24,6 @@ import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.locks.ReentrantLock; @@ -43,8 +34,13 @@ import javax.crypto.SecretKey; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator; +import com.nimbusds.jose.jwk.source.JWKSetSource; import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.jwk.source.JWKSourceBuilder; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; @@ -170,7 +166,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { .build(); // @formatter:on } - catch (KeySourceException ex) { + catch (RemoteKeySourceException ex) { this.logger.trace("Failed to retrieve JWK set", ex); if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); @@ -383,7 +379,11 @@ public final class NimbusJwtDecoder implements JwtDecoder { JWKSource jwkSource() { String jwkSetUri = this.jwkSetUri.apply(this.restOperations); - return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); + return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri)) + .refreshAheadCache(false) + .rateLimited(false) + .cache(this.cache instanceof NoOpCache) + .build(); } JWTProcessor processor() { @@ -405,16 +405,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { return new NimbusJwtDecoder(processor()); } - private static URL toURL(String url) { - try { - return new URL(url); - } - catch (MalformedURLException ex) { - throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex); - } - } - - private static final class SpringJWKSource implements JWKSource { + private static final class SpringJWKSource implements JWKSetSource { private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); @@ -424,120 +415,63 @@ public final class NimbusJwtDecoder implements JwtDecoder { private final Cache cache; - private final URL url; - private final String jwkSetUri; - private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { + private JWKSet jwkSet; + + private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; this.cache = cache; - this.url = url; this.jwkSetUri = jwkSetUri; - } - - - @Override - public List get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { - String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); - JWKSet jwkSet = null; - if (cachedJwkSet != null) { - jwkSet = parse(cachedJwkSet); - } - if (jwkSet == null) { - if(reentrantLock.tryLock()) { - try { - String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class); - if (cachedJwkSetAfterLock != null) { - jwkSet = parse(cachedJwkSetAfterLock); - } - if(jwkSet == null) { - try { - jwkSet = fetchJWKSet(); - } catch (IOException e) { - throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); - } - } - } finally { - reentrantLock.unlock(); - } - } - } - List matches = jwkSelector.select(jwkSet); - if(!matches.isEmpty()) { - return matches; - } - String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); - if (soughtKeyID == null) { - return Collections.emptyList(); - } - if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { - return Collections.emptyList(); - } - - if(reentrantLock.tryLock()) { + String jwks = this.cache.get(this.jwkSetUri, String.class); + if (jwks != null) { try { - String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); - JWKSet cacheJwkSet = parse(jwkSetUri); - if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { - try { - jwkSet = fetchJWKSet(); - } catch (IOException e) { - throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); - } - } else if (jwkSetUri != null) { - jwkSet = parse(jwkSetUri); - } - } finally { - reentrantLock.unlock(); + this.jwkSet = JWKSet.parse(jwks); + } + catch (ParseException ignored) { + // Ignore invalid cache value } } - if(jwkSet == null) { - return Collections.emptyList(); - } - return jwkSelector.select(jwkSet); } - private JWKSet fetchJWKSet() throws IOException, KeySourceException { + private String fetchJwks() throws Exception { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - ResponseEntity response = getResponse(headers); - if (response.getStatusCode().value() != 200) { - throw new IOException(response.toString()); - } + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri)); + ResponseEntity response = this.restOperations.exchange(request, String.class); + String jwks = response.getBody(); + this.jwkSet = JWKSet.parse(jwks); + return jwks; + } + + @Override + public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context) + throws KeySourceException { try { - String jwkSet = response.getBody(); - this.cache.put(this.jwkSetUri, jwkSet); - return JWKSet.parse(jwkSet); - } catch (ParseException e) { - throw new JWKSetParseException("Unable to parse JWK set", e); + this.reentrantLock.lock(); + if (refreshEvaluator.requiresRefresh(this.jwkSet)) { + this.cache.invalidate(); + } + this.cache.get(this.jwkSetUri, this::fetchJwks); + return this.jwkSet; + } + catch (Cache.ValueRetrievalException ex) { + if (ex.getCause() instanceof RemoteKeySourceException keys) { + throw keys; + } + throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause()); + } + finally { + this.reentrantLock.unlock(); } } - private ResponseEntity getResponse(HttpHeaders headers) throws IOException { - try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI()); - return this.restOperations.exchange(request, String.class); - } catch (Exception ex) { - throw new IOException(ex); - } + @Override + public void close() { + } - private JWKSet parse(String cachedJwkSet) { - JWKSet jwkSet = null; - try { - jwkSet = JWKSet.parse(cachedJwkSet); - } catch (ParseException ignored) { - // Ignore invalid cache value - } - return jwkSet; - } - - private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) { - Set keyIDs = jwkMatcher.getKeyIDs(); - return (keyIDs == null || keyIDs.isEmpty()) ? - null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null); - } } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index 378a6dbd41..f343cd2b69 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2025 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -308,7 +308,6 @@ public class JwtDecodersTests { private void prepareConfigurationResponse(String body) { this.server.enqueue(response(body)); this.server.enqueue(response(JWK_SET)); - this.server.enqueue(response(JWK_SET)); // default NoOpCache } private void prepareConfigurationResponseOidc() {