ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId

Issue: gh-4921
This commit is contained in:
Rob Winch 2018-09-04 15:59:35 -05:00
parent 28537fa3b6
commit 158b8aa6d5
3 changed files with 79 additions and 18 deletions

View File

@ -27,12 +27,15 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2ClientException;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
*/
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER"));
private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
new WebClientReactiveClientCredentialsTokenResponseClient();
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
@ -164,6 +174,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
}
/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
}
/**
* An access token will be considered expired by comparing its expiration to now +
* this skewed Duration. The default is 1 minute.
@ -208,7 +229,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
ServerWebExchange exchange, Authentication principal) {
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
}
private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
}
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
return currentAuthentication()
.flatMap(principal -> {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, (principal != null ?
principal.getName() :
"anonymousUser"),
tokenResponse.getAccessToken());
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
.thenReturn(authorizedClient);
});
}
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {

View File

@ -37,6 +37,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -71,7 +72,10 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
@RunWith(MockitoJUnitRunner.class)
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
@ -125,7 +129,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
@ -140,7 +144,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -154,7 +158,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
@ -174,7 +178,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
@ -189,7 +193,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -201,7 +205,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.function.filter(request, this.exchange)
.block();
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
@ -221,7 +225,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
@ -243,7 +247,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenNotExpiredThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -266,12 +270,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
.build();

View File

@ -18,7 +18,9 @@ package sample.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.web.reactive.function.client.WebClient;
/**
@ -29,9 +31,10 @@ import org.springframework.web.reactive.function.client.WebClient;
public class WebClientConfig {
@Bean
WebClient webClient() {
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
return WebClient.builder()
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
.build();
}
}