Share JWKSource Instances

Closes gh-10312
This commit is contained in:
Josh Cummings 2021-09-22 11:36:54 -06:00
parent 4e7c9bee46
commit 7b599d4770
10 changed files with 127 additions and 59 deletions

View File

@ -1250,7 +1250,6 @@ public class OAuth2ResourceServerConfigurerTests {
String jwtThree = jwtFromIssuer(issuerThree);
mockWebServer(String.format(metadata, issuerOne, issuerOne));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").with(bearerToken(jwtOne)))
.andExpect(status().isOk())
@ -1258,7 +1257,6 @@ public class OAuth2ResourceServerConfigurerTests {
// @formatter:on
mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").with(bearerToken(jwtTwo)))
.andExpect(status().isOk())
@ -1266,7 +1264,6 @@ public class OAuth2ResourceServerConfigurerTests {
// @formatter:on
mockWebServer(String.format(metadata, issuerThree, issuerThree));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").with(bearerToken(jwtThree)))
.andExpect(status().isUnauthorized())

View File

@ -707,21 +707,18 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
String jwtThree = jwtFromIssuer(issuerThree);
mockWebServer(String.format(metadata, issuerOne, issuerOne));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtOne))
.andExpect(status().isNotFound());
// @formatter:on
mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtTwo))
.andExpect(status().isNotFound());
// @formatter:on
mockWebServer(String.format(metadata, issuerThree, issuerThree));
mockWebServer(jwkSet);
mockWebServer(jwkSet);
// @formatter:off
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtThree))
.andExpect(status().isUnauthorized())

View File

@ -31,7 +31,10 @@ import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.RequestEntity;
@ -82,7 +85,17 @@ final class JwtDecoderProviderConfigurationUtils {
+ "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\"");
}
static Set<SignatureAlgorithm> getSignatureAlgorithms(JWKSource<SecurityContext> jwkSource) {
static <C extends SecurityContext> void addJWSAlgorithms(ConfigurableJWTProcessor<C> jwtProcessor) {
JWSKeySelector<C> selector = jwtProcessor.getJWSKeySelector();
if (selector instanceof JWSVerificationKeySelector) {
JWKSource<C> jwkSource = ((JWSVerificationKeySelector<C>) selector).getJWKSource();
Set<JWSAlgorithm> algorithms = getJWSAlgorithms(jwkSource);
selector = new JWSVerificationKeySelector<>(algorithms, jwkSource);
jwtProcessor.setJWSKeySelector(selector);
}
}
static <C extends SecurityContext> Set<JWSAlgorithm> getJWSAlgorithms(JWKSource<C> jwkSource) {
JWKMatcher jwkMatcher = new JWKMatcher.Builder().publicOnly(true).keyUses(KeyUse.SIGNATURE, null)
.keyTypes(KeyType.RSA, KeyType.EC).build();
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
@ -106,6 +119,12 @@ final class JwtDecoderProviderConfigurationUtils {
catch (KeySourceException ex) {
throw new IllegalStateException(ex);
}
Assert.notEmpty(jwsAlgorithms, "Failed to find any algorithms from the JWK set");
return jwsAlgorithms;
}
static Set<SignatureAlgorithm> getSignatureAlgorithms(JWKSource<SecurityContext> jwkSource) {
Set<JWSAlgorithm> jwsAlgorithms = getJWSAlgorithms(jwkSource);
Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
for (JWSAlgorithm jwsAlgorithm : jwsAlgorithms) {
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(jwsAlgorithm.getName());
@ -113,7 +132,6 @@ final class JwtDecoderProviderConfigurationUtils {
signatureAlgorithms.add(signatureAlgorithm);
}
}
Assert.notEmpty(signatureAlgorithms, "Failed to find any algorithms from the JWK set");
return signatureAlgorithms;
}

View File

@ -16,17 +16,9 @@
package org.springframework.security.oauth2.jwt;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.Map;
import java.util.Set;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.SecurityContext;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
/**
@ -117,22 +109,10 @@ public final class JwtDecoders {
JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
String jwkSetUri = configuration.get("jwks_uri").toString();
RemoteJWKSet<SecurityContext> jwkSource = new RemoteJWKSet<>(url(jwkSetUri));
Set<SignatureAlgorithm> signatureAlgorithms = JwtDecoderProviderConfigurationUtils
.getSignatureAlgorithms(jwkSource);
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithms((algs) -> algs.addAll(signatureAlgorithms)).build();
.jwtProcessorCustomizer(JwtDecoderProviderConfigurationUtils::addJWSAlgorithms).build();
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
}
private static URL url(String url) {
try {
return new URL(url);
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
}
}

View File

@ -17,12 +17,14 @@
package org.springframework.security.oauth2.jwt;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
@ -274,19 +276,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
*/
public static final class JwkSetUriReactiveJwtDecoderBuilder {
private static final Duration FOREVER = Duration.ofMillis(Long.MAX_VALUE);
private final String jwkSetUri;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private WebClient webClient = WebClient.create();
private Consumer<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorCustomizer;
private BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer;
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
this.jwkSetUri = jwkSetUri;
this.jwtProcessorCustomizer = (processor) -> {
};
this.jwtProcessorCustomizer = (source, processor) -> Mono.just(processor);
}
/**
@ -342,6 +345,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
public JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
Consumer<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorCustomizer) {
Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
this.jwtProcessorCustomizer = (source, processor) -> {
jwtProcessorCustomizer.accept(processor);
return Mono.just(processor);
};
return this;
}
JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer) {
Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
this.jwtProcessorCustomizer = jwtProcessorCustomizer;
return this;
}
@ -373,15 +386,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
});
this.jwtProcessorCustomizer.accept(jwtProcessor);
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
source.setWebClient(this.webClient);
Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
Mono<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorMono = this.jwtProcessorCustomizer
.apply(source, jwtProcessor)
.cache((processor) -> FOREVER, (ex) -> Duration.ZERO, () -> Duration.ZERO);
return (jwt) -> {
JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
return source.get(selector)
return jwtProcessorMono.flatMap((processor) -> source.get(selector)
.onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex))
.map((jwkList) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
.map((jwkList) -> createClaimsSet(processor, jwt, new JWKSecurityContext(jwkList))));
};
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
import java.util.HashSet;
import java.util.Set;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
final class ReactiveJwtDecoderProviderConfigurationUtils {
static <C extends SecurityContext> Mono<ConfigurableJWTProcessor<C>> addJWSAlgorithms(
ReactiveRemoteJWKSource jwkSource, ConfigurableJWTProcessor<C> jwtProcessor) {
JWSKeySelector<C> selector = jwtProcessor.getJWSKeySelector();
if (!(selector instanceof JWSVerificationKeySelector)) {
return Mono.just(jwtProcessor);
}
JWKSource<C> delegate = ((JWSVerificationKeySelector<C>) selector).getJWKSource();
return getJWSAlgorithms(jwkSource).map((algorithms) -> new JWSVerificationKeySelector<>(algorithms, delegate))
.map((replacement) -> {
jwtProcessor.setJWSKeySelector(replacement);
return jwtProcessor;
});
}
static Mono<Set<JWSAlgorithm>> getJWSAlgorithms(ReactiveRemoteJWKSource jwkSource) {
JWKMatcher jwkMatcher = new JWKMatcher.Builder().publicOnly(true).keyUses(KeyUse.SIGNATURE, null)
.keyTypes(KeyType.RSA, KeyType.EC).build();
return jwkSource.get(new JWKSelector(jwkMatcher)).map((jwks) -> {
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (JWK jwk : jwks) {
if (jwk.getAlgorithm() != null) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(jwk.getAlgorithm().getName());
jwsAlgorithms.add(jwsAlgorithm);
}
else {
if (jwk.getKeyType() == KeyType.RSA) {
jwsAlgorithms.addAll(JWSAlgorithm.Family.RSA);
}
else if (jwk.getKeyType() == KeyType.EC) {
jwsAlgorithms.addAll(JWSAlgorithm.Family.EC);
}
}
}
Assert.notEmpty(jwsAlgorithms, "Failed to find any algorithms from the JWK set");
return jwsAlgorithms;
}).onErrorMap(KeySourceException.class, (ex) -> new IllegalStateException(ex));
}
private ReactiveJwtDecoderProviderConfigurationUtils() {
}
}

View File

@ -16,17 +16,9 @@
package org.springframework.security.oauth2.jwt;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.Map;
import java.util.Set;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.SecurityContext;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
/**
@ -115,22 +107,10 @@ public final class ReactiveJwtDecoders {
JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
String jwkSetUri = configuration.get("jwks_uri").toString();
RemoteJWKSet<SecurityContext> jwkSource = new RemoteJWKSet<>(url(jwkSetUri));
Set<SignatureAlgorithm> signatureAlgorithms = JwtDecoderProviderConfigurationUtils
.getSignatureAlgorithms(jwkSource);
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithms((algs) -> algs.addAll(signatureAlgorithms)).build();
.jwtProcessorCustomizer(ReactiveJwtDecoderProviderConfigurationUtils::addJWSAlgorithms).build();
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
}
private static URL url(String url) {
try {
return new URL(url);
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
}
}

View File

@ -308,7 +308,6 @@ public class JwtDecodersTests {
private void prepareConfigurationResponse(String body) {
this.server.enqueue(response(body));
this.server.enqueue(response(JWK_SET));
this.server.enqueue(response(JWK_SET));
}
private void prepareConfigurationResponseOidc() {

View File

@ -29,6 +29,7 @@ import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.Map;
import java.util.function.Consumer;
import javax.crypto.SecretKey;
@ -45,6 +46,7 @@ import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.jupiter.api.AfterEach;
@ -314,7 +316,7 @@ public class NimbusReactiveJwtDecoderTests {
assertThatIllegalArgumentException()
.isThrownBy(() -> NimbusReactiveJwtDecoder
.withJwkSetUri(this.jwkSetUri)
.jwtProcessorCustomizer(null)
.jwtProcessorCustomizer((Consumer<ConfigurableJWTProcessor<JWKSecurityContext>>) null)
.build()
)
.withMessage("jwtProcessorCustomizer cannot be null");

View File

@ -282,7 +282,6 @@ public class ReactiveJwtDecodersTests {
private void prepareConfigurationResponse(String body) {
this.server.enqueue(response(body));
this.server.enqueue(response(JWK_SET));
this.server.enqueue(response(JWK_SET));
}
private void prepareConfigurationResponseOidc() {