From 7b7abb28bb72d7b5378599ebbd8436da53a402d2 Mon Sep 17 00:00:00 2001 From: Daeho Kwon Date: Wed, 5 Feb 2025 02:51:32 +0900 Subject: [PATCH] Remove Deprecated Usages of RemoteJWKSet Closes gh-16251 Signed-off-by: Daeho Kwon --- .../security/oauth2/jwt/NimbusJwtDecoder.java | 190 +++++++++++------- .../security/oauth2/jwt/JwtDecodersTests.java | 3 +- 2 files changed, 120 insertions(+), 73 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 2713ee96b2..732ecc2476 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,12 @@ package org.springframework.security.oauth2.jwt; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKMatcher; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.source.JWKSetParseException; +import com.nimbusds.jose.jwk.source.JWKSetRetrievalException; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; @@ -26,8 +32,10 @@ import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.function.Function; @@ -35,17 +43,12 @@ import javax.crypto.SecretKey; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jose.jwk.source.JWKSetCache; import com.nimbusds.jose.jwk.source.JWKSource; -import com.nimbusds.jose.jwk.source.RemoteJWKSet; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; -import com.nimbusds.jose.util.Resource; -import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; @@ -57,6 +60,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cache.Cache; +import org.springframework.cache.support.NoOpCache; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate; * @author Josh Cummings * @author Joe Grandja * @author Mykyta Bezverkhyi + * @author Daeho Kwon * @since 5.2 */ public final class NimbusJwtDecoder implements JwtDecoder { @@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { .build(); // @formatter:on } - catch (RemoteKeySourceException ex) { + catch (KeySourceException ex) { this.logger.trace("Failed to retrieve JWK set", ex); if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); @@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { private RestOperations restOperations = new RestTemplate(); - private Cache cache; + private Cache cache = new NoOpCache("default"); private Consumer> jwtProcessorCustomizer; @@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } - JWKSource jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { - if (this.cache == null) { - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); - } - JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache); + JWKSource jwkSource() { + String jwkSetUri = this.jwkSetUri.apply(this.restOperations); + return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); } JWTProcessor processor() { - ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); - String jwkSetUri = this.jwkSetUri.apply(this.restOperations); - JWKSource jwkSource = jwkSource(jwkSetRetriever, jwkSetUri); + JWKSource jwkSource = jwkSource(); ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); // Spring Security validates the claim set independent from Nimbus @@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder { } } - private static final class SpringJWKSetCache implements JWKSetCache { - - private final String jwkSetUri; - - private final Cache cache; - - private JWKSet jwkSet; - - SpringJWKSetCache(String jwkSetUri, Cache cache) { - this.jwkSetUri = jwkSetUri; - this.cache = cache; - this.updateJwkSetFromCache(); - } - - private void updateJwkSetFromCache() { - String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); - if (cachedJwkSet != null) { - try { - this.jwkSet = JWKSet.parse(cachedJwkSet); - } - catch (ParseException ignored) { - // Ignore invalid cache value - } - } - } - - // Note: Only called from inside a synchronized block in RemoteJWKSet. - @Override - public void put(JWKSet jwkSet) { - this.jwkSet = jwkSet; - this.cache.put(this.jwkSetUri, jwkSet.toString(false)); - } - - @Override - public JWKSet get() { - return (!requiresRefresh()) ? this.jwkSet : null; - - } - - @Override - public boolean requiresRefresh() { - return this.cache.get(this.jwkSetUri) == null; - } - - } - - private static class RestOperationsResourceRetriever implements ResourceRetriever { + private static final class SpringJWKSource implements JWKSource { private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); + private final ReentrantLock reentrantLock = new ReentrantLock(); + private final RestOperations restOperations; - RestOperationsResourceRetriever(RestOperations restOperations) { + private final Cache cache; + + private final URL url; + + private final String jwkSetUri; + + private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; + this.cache = cache; + this.url = url; + this.jwkSetUri = jwkSetUri; } + @Override - public Resource retrieveResource(URL url) throws IOException { + public List get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { + String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); + JWKSet jwkSet = null; + if (cachedJwkSet != null) { + jwkSet = parse(cachedJwkSet); + } + if (jwkSet == null) { + if(reentrantLock.tryLock()) { + try { + String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class); + if (cachedJwkSetAfterLock != null) { + jwkSet = parse(cachedJwkSetAfterLock); + } + if(jwkSet == null) { + try { + jwkSet = fetchJWKSet(); + } catch (IOException e) { + throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); + } + } + } finally { + reentrantLock.unlock(); + } + } + } + List matches = jwkSelector.select(jwkSet); + if(!matches.isEmpty()) { + return matches; + } + String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); + if (soughtKeyID == null) { + return Collections.emptyList(); + } + if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { + return Collections.emptyList(); + } + + if(reentrantLock.tryLock()) { + try { + String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); + JWKSet cacheJwkSet = parse(jwkSetUri); + if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { + try { + jwkSet = fetchJWKSet(); + } catch (IOException e) { + throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); + } + } else if (jwkSetUri != null) { + jwkSet = parse(jwkSetUri); + } + } finally { + reentrantLock.unlock(); + } + } + if(jwkSet == null) { + return Collections.emptyList(); + } + return jwkSelector.select(jwkSet); + } + + private JWKSet fetchJWKSet() throws IOException, KeySourceException { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - ResponseEntity response = getResponse(url, headers); + ResponseEntity response = getResponse(headers); if (response.getStatusCode().value() != 200) { throw new IOException(response.toString()); } - return new Resource(response.getBody(), "UTF-8"); + try { + String jwkSet = response.getBody(); + this.cache.put(this.jwkSetUri, jwkSet); + return JWKSet.parse(jwkSet); + } catch (ParseException e) { + throw new JWKSetParseException("Unable to parse JWK set", e); + } } - private ResponseEntity getResponse(URL url, HttpHeaders headers) throws IOException { + private ResponseEntity getResponse(HttpHeaders headers) throws IOException { try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI()); return this.restOperations.exchange(request, String.class); - } - catch (Exception ex) { + } catch (Exception ex) { throw new IOException(ex); } } + private JWKSet parse(String cachedJwkSet) { + JWKSet jwkSet = null; + try { + jwkSet = JWKSet.parse(cachedJwkSet); + } catch (ParseException ignored) { + // Ignore invalid cache value + } + return jwkSet; + } + + private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) { + Set keyIDs = jwkMatcher.getKeyIDs(); + return (keyIDs == null || keyIDs.isEmpty()) ? + null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null); + } } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index f343cd2b69..378a6dbd41 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -308,6 +308,7 @@ public class JwtDecodersTests { private void prepareConfigurationResponse(String body) { this.server.enqueue(response(body)); this.server.enqueue(response(JWK_SET)); + this.server.enqueue(response(JWK_SET)); // default NoOpCache } private void prepareConfigurationResponseOidc() {