mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-05-31 17:22:13 +00:00
ServerOAuth2AuthorizedClientExchangeFilterFunction uses ServerOAuth2AuthorizedClientRepository
Issue: gh-4921
This commit is contained in:
parent
07b6699fd9
commit
5bcbb1c40f
@ -24,8 +24,8 @@ import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.util.Assert;
|
||||
@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.net.URI;
|
||||
@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||
*/
|
||||
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
|
||||
|
||||
/**
|
||||
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
|
||||
*/
|
||||
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
|
||||
|
||||
private Clock clock = Clock.systemUTC();
|
||||
|
||||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
||||
|
||||
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
|
||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
|
||||
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
|
||||
this.authorizedClientService = authorizedClientService;
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||
this.authorizedClientRepository = authorizedClientRepository;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -78,7 +85,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||
*
|
||||
* <pre>
|
||||
* WebClient webClient = WebClient.builder()
|
||||
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
|
||||
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
|
||||
* .build();
|
||||
* Mono<String> response = webClient
|
||||
* .get()
|
||||
@ -110,6 +117,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
|
||||
* providing the Bearer Token. Example usage:
|
||||
*
|
||||
* <pre>
|
||||
* WebClient webClient = WebClient.builder()
|
||||
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
|
||||
* .build();
|
||||
* Mono<String> response = webClient
|
||||
* .get()
|
||||
* .uri(uri)
|
||||
* .attributes(serverWebExchange(serverWebExchange))
|
||||
* // ...
|
||||
* .retrieve()
|
||||
* .bodyToMono(String.class);
|
||||
* </pre>
|
||||
* @param serverWebExchange the {@link ServerWebExchange} to use
|
||||
* @return the {@link Consumer} to populate the client request attributes
|
||||
*/
|
||||
public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange serverWebExchange) {
|
||||
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
|
||||
}
|
||||
|
||||
/**
|
||||
* An access token will be considered expired by comparing its expiration to now +
|
||||
* this skewed Duration. The default is 1 minute.
|
||||
@ -124,22 +155,23 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
||||
.map(OAuth2AuthorizedClient.class::cast);
|
||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||
return Mono.justOrEmpty(attribute)
|
||||
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient))
|
||||
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
|
||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||
.flatMap(next::exchange)
|
||||
.switchIfEmpty(next.exchange(request));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
|
||||
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
||||
if (shouldRefresh(authorizedClient)) {
|
||||
return refreshAuthorizedClient(next, authorizedClient);
|
||||
return refreshAuthorizedClient(next, authorizedClient, exchange);
|
||||
}
|
||||
return Mono.just(authorizedClient);
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
|
||||
OAuth2AuthorizedClient authorizedClient) {
|
||||
OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
||||
ClientRegistration clientRegistration = authorizedClient
|
||||
.getClientRegistration();
|
||||
String tokenUri = clientRegistration
|
||||
@ -155,12 +187,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||
.flatMap(result -> ReactiveSecurityContextHolder.getContext()
|
||||
.map(SecurityContext::getAuthentication)
|
||||
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
|
||||
.flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal))
|
||||
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
|
||||
.thenReturn(result));
|
||||
}
|
||||
|
||||
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
|
||||
if (this.authorizedClientService == null) {
|
||||
if (this.authorizedClientRepository == null) {
|
||||
return false;
|
||||
}
|
||||
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
|
||||
|
@ -36,9 +36,9 @@ import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
@ -70,7 +70,7 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
@Mock
|
||||
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
|
||||
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
|
||||
|
||||
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
||||
|
||||
@ -124,7 +124,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
|
||||
@Test
|
||||
public void filterWhenRefreshRequiredThenRefresh() {
|
||||
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
|
||||
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
@ -139,7 +139,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
this.accessToken.getTokenValue(),
|
||||
issuedAt,
|
||||
accessTokenExpiresAt);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
@ -153,7 +153,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.block();
|
||||
|
||||
verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication));
|
||||
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
@ -173,7 +173,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
|
||||
@Test
|
||||
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
|
||||
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
|
||||
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
@ -188,7 +188,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
this.accessToken.getTokenValue(),
|
||||
issuedAt,
|
||||
accessTokenExpiresAt);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
@ -200,7 +200,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
this.function.filter(request, this.exchange)
|
||||
.block();
|
||||
|
||||
verify(this.authorizedClientService).saveAuthorizedClient(any(), any());
|
||||
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
@ -220,7 +220,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
|
||||
@Test
|
||||
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
"principalName", this.accessToken);
|
||||
@ -242,7 +242,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
|
||||
@Test
|
||||
public void filterWhenNotExpiredThenShouldRefreshFalse() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
|
Loading…
x
Reference in New Issue
Block a user