From 158b8aa6d5750e08be86ff6f184aa5a7d6c8da42 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 4 Sep 2018 15:59:35 -0500 Subject: [PATCH] ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId Issue: gh-4921 --- ...uthorizedClientExchangeFilterFunction.java | 63 +++++++++++++++++-- ...izedClientExchangeFilterFunctionTests.java | 27 ++++---- .../java/sample/config/WebClientConfig.java | 7 ++- 3 files changed, 79 insertions(+), 18 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 11c560fccc..1ccfdfd99c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -27,12 +27,15 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2ClientException; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; @@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}. */ private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName(); - public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); private Clock clock = Clock.systemUTC(); private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = + new WebClientReactiveClientCredentialsTokenResponseClient(); + + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; public ServerOAuth2AuthorizedClientExchangeFilterFunction() {} - public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; } @@ -164,6 +174,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); } + /** + * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. @@ -208,7 +229,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Mono loadAuthorizedClient(String clientRegistrationId, ServerWebExchange exchange, Authentication principal) { return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange) - .switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId))); + .switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange)); + } + + private Mono authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) + .flatMap(clientRegistration -> { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return clientCredentials(clientRegistration, exchange); + } + return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); + }); + } + + private Mono clientCredentials( + ClientRegistration clientRegistration, ServerWebExchange exchange) { + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) + .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange)); + } + + private Mono clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) { + return currentAuthentication() + .flatMap(principal -> { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, (principal != null ? + principal.getName() : + "anonymousUser"), + tokenResponse.getAccessToken()); + + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null) + .thenReturn(authorizedClient); + }); } private Mono refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 579bda1505..630ae529e9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -37,6 +37,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -71,7 +72,10 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c @RunWith(MockitoJUnitRunner.class) public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock - private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(); @@ -125,7 +129,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { - when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) @@ -140,7 +144,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -154,7 +158,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); - verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); @@ -174,7 +178,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { - when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) @@ -189,7 +193,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -201,7 +205,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.function.filter(request, this.exchange) .block(); - verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); @@ -221,7 +225,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); @@ -243,7 +247,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -266,12 +270,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); - when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(clientRegistrationId(this.registration.getRegistrationId())) .build(); diff --git a/samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java b/samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java index 17eb56f756..0574cf75e3 100644 --- a/samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java @@ -18,7 +18,9 @@ package sample.config; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.web.reactive.function.client.WebClient; /** @@ -29,9 +31,10 @@ import org.springframework.web.reactive.function.client.WebClient; public class WebClientConfig { @Bean - WebClient webClient() { + WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { return WebClient.builder() - .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction()) + .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository)) .build(); } }