diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index 24bc732113..2ac4dc5bcc 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -43,18 +43,24 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private final AtomicReference> cachedJWKSet = new AtomicReference<>(Mono.empty()); + /** + * The cached JWK set URL. + */ + private final AtomicReference cachedJwkSetUrl = new AtomicReference<>(); + private WebClient webClient = WebClient.create(); - private final Mono jwkSetURL; + private final Mono jwkSetUrlProvider; ReactiveRemoteJWKSource(String jwkSetURL) { Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty"); - this.jwkSetURL = Mono.just(jwkSetURL); + this.jwkSetUrlProvider = Mono.just(jwkSetURL); } - ReactiveRemoteJWKSource(Mono jwkSetURL) { - Assert.notNull(jwkSetURL, "jwkSetURL cannot be null"); - this.jwkSetURL = jwkSetURL.cache(); + ReactiveRemoteJWKSource(Mono jwkSetUrlProvider) { + Assert.notNull(jwkSetUrlProvider, "jwkSetUrlProvider cannot be null"); + this.jwkSetUrlProvider = Mono.fromCallable(this.cachedJwkSetUrl::get) + .switchIfEmpty(Mono.defer(() -> jwkSetUrlProvider.doOnNext(this.cachedJwkSetUrl::set))); } @Override @@ -100,13 +106,15 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private Mono getJWKSet() { // @formatter:off - return this.jwkSetURL.flatMap((jwkSetURL) -> this.webClient.get() - .uri(jwkSetURL) - .retrieve() - .bodyToMono(String.class)) + return this.jwkSetUrlProvider + .flatMap((jwkSetURL) -> this.webClient.get() + .uri(jwkSetURL) + .retrieve() + .bodyToMono(String.class) + ) .map(this::parse) .doOnNext((jwkSet) -> this.cachedJWKSet - .set(Mono.just(jwkSet)) + .set(Mono.just(jwkSet)) ) .cache(); // @formatter:on diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index ddcc1c913f..8ecabf791c 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.jwt; import java.util.Collections; import java.util.List; +import java.util.function.Supplier; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; @@ -31,10 +32,16 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Mono; + +import org.springframework.web.reactive.function.client.WebClientResponseException; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willReturn; +import static org.mockito.BDDMockito.willThrow; /** * @author Rob Winch @@ -52,6 +59,9 @@ public class ReactiveRemoteJWKSourceTests { private MockWebServer server; + @Mock + private Supplier mockStringSupplier; + // @formatter:off private String keys = "{\n" + " \"keys\": [\n" @@ -156,4 +166,18 @@ public class ReactiveRemoteJWKSourceTests { assertThat(this.source.get(this.selector).block()).isEmpty(); } + @Test + public void getShouldRecoverAndReturnKeysAfterErrorCase() { + given(this.matcher.matches(any())).willReturn(true); + this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier)); + willThrow(WebClientResponseException.ServiceUnavailable.class).given(this.mockStringSupplier).get(); + // first case: id provider has error state + assertThatExceptionOfType(WebClientResponseException.ServiceUnavailable.class) + .isThrownBy(() -> this.source.get(this.selector).block()); + // second case: id provider is healthy again + willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get(); + List actual = this.source.get(this.selector).block(); + assertThat(actual).isNotEmpty(); + } + }