ServerOAuth2AuthorizedClientExchangeFilterFunction uses ServerOAuth2AuthorizedClientRepository

Issue: gh-4921
This commit is contained in:
Rob Winch 2018-08-24 12:50:39 -05:00
parent 07b6699fd9
commit 5bcbb1c40f
2 changed files with 53 additions and 21 deletions

View File

@ -24,8 +24,8 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.util.Assert;
@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.net.URI;
@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
*/
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
/**
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
*/
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
this.authorizedClientService = authorizedClientService;
public ServerOAuth2AuthorizedClientExchangeFilterFunction(
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientRepository = authorizedClientRepository;
}
/**
@ -78,7 +85,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
*
* <pre>
* WebClient webClient = WebClient.builder()
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
* .build();
* Mono<String> response = webClient
* .get()
@ -110,6 +117,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
}
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
* providing the Bearer Token. Example usage:
*
* <pre>
* WebClient webClient = WebClient.builder()
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
* .build();
* Mono<String> response = webClient
* .get()
* .uri(uri)
* .attributes(serverWebExchange(serverWebExchange))
* // ...
* .retrieve()
* .bodyToMono(String.class);
* </pre>
* @param serverWebExchange the {@link ServerWebExchange} to use
* @return the {@link Consumer} to populate the client request attributes
*/
public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange serverWebExchange) {
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
}
/**
* An access token will be considered expired by comparing its expiration to now +
* this skewed Duration. The default is 1 minute.
@ -124,22 +155,23 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
return Mono.justOrEmpty(attribute)
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient))
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
if (shouldRefresh(authorizedClient)) {
return refreshAuthorizedClient(next, authorizedClient);
return refreshAuthorizedClient(next, authorizedClient, exchange);
}
return Mono.just(authorizedClient);
}
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient) {
OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
ClientRegistration clientRegistration = authorizedClient
.getClientRegistration();
String tokenUri = clientRegistration
@ -155,12 +187,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.flatMap(result -> ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
.flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal))
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
.thenReturn(result));
}
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientService == null) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();

View File

@ -36,9 +36,9 @@ import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -70,7 +70,7 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
@RunWith(MockitoJUnitRunner.class)
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
@ -124,7 +124,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
@ -139,7 +139,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -153,7 +153,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication));
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
@ -173,7 +173,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
@ -188,7 +188,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -200,7 +200,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.function.filter(request, this.exchange)
.block();
verify(this.authorizedClientService).saveAuthorizedClient(any(), any());
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
@ -220,7 +220,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
@ -242,7 +242,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenNotExpiredThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,