Fix caching error state in ReactiveRemoteJWKSource

This commit is contained in:
Veli Döngelci 2023-10-09 16:23:30 -05:00 committed by Steve Riesenberg
parent 70ad3bf749
commit a6b872dcf3
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
2 changed files with 48 additions and 10 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,23 +43,33 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
*/
private final AtomicReference<Mono<JWKSet>> cachedJWKSet = new AtomicReference<>(Mono.empty());
/**
* cached url for jwk set.
*/
private final AtomicReference<String> cachedJwkSetUrl = new AtomicReference<>();
private WebClient webClient = WebClient.create();
private final String jwkSetURL;
private Mono<String> jwkSetURLProvider;
ReactiveRemoteJWKSource(String jwkSetURL) {
Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty");
this.jwkSetURL = jwkSetURL;
this.cachedJwkSetUrl.set(jwkSetURL);
}
ReactiveRemoteJWKSource(Mono<String> jwkSetURLProvider) {
Assert.notNull(jwkSetURLProvider, "jwkSetURLProvider cannot be null");
this.jwkSetURLProvider = jwkSetURLProvider;
}
@Override
public Mono<List<JWK>> get(JWKSelector jwkSelector) {
// @formatter:off
return this.cachedJWKSet.get()
.switchIfEmpty(Mono.defer(() -> getJWKSet()))
.switchIfEmpty(Mono.defer(this::getJWKSet))
.flatMap((jwkSet) -> get(jwkSelector, jwkSet))
.switchIfEmpty(Mono.defer(() -> getJWKSet()
.map((jwkSet) -> jwkSelector.select(jwkSet)))
.map(jwkSelector::select))
);
// @formatter:on
}
@ -95,13 +105,18 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
*/
private Mono<JWKSet> getJWKSet() {
// @formatter:off
return this.webClient.get()
.uri(this.jwkSetURL)
.retrieve()
.bodyToMono(String.class)
return Mono.justOrEmpty(this.cachedJwkSetUrl.get())
.switchIfEmpty(Mono.defer(() -> this.jwkSetURLProvider
.doOnNext(this.cachedJwkSetUrl::set))
)
.flatMap((jwkSetURL) -> this.webClient.get()
.uri(jwkSetURL)
.retrieve()
.bodyToMono(String.class)
)
.map(this::parse)
.doOnNext((jwkSet) -> this.cachedJWKSet
.set(Mono.just(jwkSet))
.set(Mono.just(jwkSet))
)
.cache();
// @formatter:on

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.jwt;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
@ -31,10 +32,15 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
/**
* @author Rob Winch
@ -52,6 +58,9 @@ public class ReactiveRemoteJWKSourceTests {
private MockWebServer server;
@Mock
private Supplier<String> mockStringSupplier;
// @formatter:off
private String keys = "{\n"
+ " \"keys\": [\n"
@ -156,4 +165,18 @@ public class ReactiveRemoteJWKSourceTests {
assertThat(this.source.get(this.selector).block()).isEmpty();
}
@Test
public void getShouldRecoverAndReturnKeysAfterErrorCase() {
given(this.matcher.matches(any())).willReturn(true);
this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(mockStringSupplier));
doThrow(WebClientResponseException.ServiceUnavailable.class).when(this.mockStringSupplier).get();
// first case: id provider has error state
assertThatThrownBy(() -> this.source.get(this.selector).block())
.isExactlyInstanceOf(WebClientResponseException.ServiceUnavailable.class);
// second case: id provider is healthy again
doReturn(this.server.url("/").toString()).when(this.mockStringSupplier).get();
var actual = this.source.get(this.selector).block();
assertThat(actual).isNotEmpty();
}
}