Polish Nimbus JWK Source Implementation

Issue gh-16251
This commit is contained in:
Josh Cummings 2025-02-04 17:32:12 -07:00
parent 7b7abb28bb
commit 11113adf62
2 changed files with 52 additions and 119 deletions

View File

@ -16,15 +16,7 @@
package org.springframework.security.oauth2.jwt; package org.springframework.security.oauth2.jwt;
import com.nimbusds.jose.KeySourceException; import java.net.URI;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSetParseException;
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.text.ParseException; import java.text.ParseException;
import java.util.Arrays; import java.util.Arrays;
@ -32,7 +24,6 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@ -43,8 +34,13 @@ import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
import com.nimbusds.jose.jwk.source.JWKSetSource;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SecurityContext;
@ -170,7 +166,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
.build(); .build();
// @formatter:on // @formatter:on
} }
catch (KeySourceException ex) { catch (RemoteKeySourceException ex) {
this.logger.trace("Failed to retrieve JWK set", ex); this.logger.trace("Failed to retrieve JWK set", ex);
if (ex.getCause() instanceof ParseException) { if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@ -383,7 +379,11 @@ public final class NimbusJwtDecoder implements JwtDecoder {
JWKSource<SecurityContext> jwkSource() { JWKSource<SecurityContext> jwkSource() {
String jwkSetUri = this.jwkSetUri.apply(this.restOperations); String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri))
.refreshAheadCache(false)
.rateLimited(false)
.cache(this.cache instanceof NoOpCache)
.build();
} }
JWTProcessor<SecurityContext> processor() { JWTProcessor<SecurityContext> processor() {
@ -405,16 +405,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return new NimbusJwtDecoder(processor()); return new NimbusJwtDecoder(processor());
} }
private static URL toURL(String url) { private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {
try {
return new URL(url);
}
catch (MalformedURLException ex) {
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
}
}
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
@ -424,120 +415,63 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private final Cache cache; private final Cache cache;
private final URL url;
private final String jwkSetUri; private final String jwkSetUri;
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { private JWKSet jwkSet;
private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
Assert.notNull(restOperations, "restOperations cannot be null"); Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations; this.restOperations = restOperations;
this.cache = cache; this.cache = cache;
this.url = url;
this.jwkSetUri = jwkSetUri; this.jwkSetUri = jwkSetUri;
} String jwks = this.cache.get(this.jwkSetUri, String.class);
if (jwks != null) {
@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
JWKSet jwkSet = null;
if (cachedJwkSet != null) {
jwkSet = parse(cachedJwkSet);
}
if (jwkSet == null) {
if(reentrantLock.tryLock()) {
try {
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
if (cachedJwkSetAfterLock != null) {
jwkSet = parse(cachedJwkSetAfterLock);
}
if(jwkSet == null) {
try {
jwkSet = fetchJWKSet();
} catch (IOException e) {
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
}
}
} finally {
reentrantLock.unlock();
}
}
}
List<JWK> matches = jwkSelector.select(jwkSet);
if(!matches.isEmpty()) {
return matches;
}
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
if (soughtKeyID == null) {
return Collections.emptyList();
}
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
return Collections.emptyList();
}
if(reentrantLock.tryLock()) {
try { try {
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); this.jwkSet = JWKSet.parse(jwks);
JWKSet cacheJwkSet = parse(jwkSetUri); }
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { catch (ParseException ignored) {
try { // Ignore invalid cache value
jwkSet = fetchJWKSet();
} catch (IOException e) {
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
}
} else if (jwkSetUri != null) {
jwkSet = parse(jwkSetUri);
}
} finally {
reentrantLock.unlock();
} }
} }
if(jwkSet == null) {
return Collections.emptyList();
}
return jwkSelector.select(jwkSet);
} }
private JWKSet fetchJWKSet() throws IOException, KeySourceException { private String fetchJwks() throws Exception {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = getResponse(headers); RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
if (response.getStatusCode().value() != 200) { ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
throw new IOException(response.toString()); String jwks = response.getBody();
} this.jwkSet = JWKSet.parse(jwks);
return jwks;
}
@Override
public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
throws KeySourceException {
try { try {
String jwkSet = response.getBody(); this.reentrantLock.lock();
this.cache.put(this.jwkSetUri, jwkSet); if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
return JWKSet.parse(jwkSet); this.cache.invalidate();
} catch (ParseException e) { }
throw new JWKSetParseException("Unable to parse JWK set", e); this.cache.get(this.jwkSetUri, this::fetchJwks);
return this.jwkSet;
}
catch (Cache.ValueRetrievalException ex) {
if (ex.getCause() instanceof RemoteKeySourceException keys) {
throw keys;
}
throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
}
finally {
this.reentrantLock.unlock();
} }
} }
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException { @Override
try { public void close() {
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
return this.restOperations.exchange(request, String.class);
} catch (Exception ex) {
throw new IOException(ex);
}
} }
private JWKSet parse(String cachedJwkSet) {
JWKSet jwkSet = null;
try {
jwkSet = JWKSet.parse(cachedJwkSet);
} catch (ParseException ignored) {
// Ignore invalid cache value
}
return jwkSet;
}
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
Set<String> keyIDs = jwkMatcher.getKeyIDs();
return (keyIDs == null || keyIDs.isEmpty()) ?
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
}
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2025 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -308,7 +308,6 @@ public class JwtDecodersTests {
private void prepareConfigurationResponse(String body) { private void prepareConfigurationResponse(String body) {
this.server.enqueue(response(body)); this.server.enqueue(response(body));
this.server.enqueue(response(JWK_SET)); this.server.enqueue(response(JWK_SET));
this.server.enqueue(response(JWK_SET)); // default NoOpCache
} }
private void prepareConfigurationResponseOidc() { private void prepareConfigurationResponseOidc() {