From 89f2874bff9e06a5cbd1f34e29992e92836bac30 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 24 Aug 2018 14:55:31 -0500 Subject: [PATCH] ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId You can now provide the clientRegistrationId and ServerOAuth2AuthorizedClientExchangeFilterFunction will look up the authorized client automatically. Issue: gh-4921 --- ...uthorizedClientExchangeFilterFunction.java | 71 ++++++++++++++++--- ...izedClientExchangeFilterFunctionTests.java | 25 +++++++ 2 files changed, 87 insertions(+), 9 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 309042a772..11c560fccc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -19,14 +19,19 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2ClientException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.Assert; import org.springframework.web.reactive.function.BodyInserters; @@ -61,10 +66,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); + /** + * The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()} + */ + private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID"); + /** * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}. */ private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName(); + public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); private Clock clock = Clock.systemUTC(); @@ -74,8 +86,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements public ServerOAuth2AuthorizedClientExchangeFilterFunction() {} - public ServerOAuth2AuthorizedClientExchangeFilterFunction( - ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { this.authorizedClientRepository = authorizedClientRepository; } @@ -141,6 +152,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange); } + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to + * be used to look up the {@link OAuth2AuthorizedClient}. + * + * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to + * be used to look up the {@link OAuth2AuthorizedClient}. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> clientRegistrationId(String clientRegistrationId) { + return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); + } + /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. @@ -153,17 +176,42 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) - .map(OAuth2AuthorizedClient.class::cast); ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); - return Mono.justOrEmpty(attribute) - .flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange)) + return authorizedClient(request) + .flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange)) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); } - private Mono authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { + private Mono authorizedClient(ClientRequest request) { + Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) + .map(OAuth2AuthorizedClient.class::cast); + return Mono.justOrEmpty(attribute) + .switchIfEmpty(findAuthorizedClientByRegistrationId(request)); + } + + private Mono findAuthorizedClientByRegistrationId(ClientRequest request) { + if (this.authorizedClientRepository == null) { + return Mono.empty(); + } + String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME); + if (clientRegistrationId == null) { + return Mono.empty(); + } + ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); + return currentAuthentication() + .flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal) + ); + } + + private Mono loadAuthorizedClient(String clientRegistrationId, + ServerWebExchange exchange, Authentication principal) { + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange) + .switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId))); + } + + private Mono refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { if (shouldRefresh(authorizedClient)) { return refreshAuthorizedClient(next, authorizedClient, exchange); } @@ -184,13 +232,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return next.exchange(request) .flatMap(response -> response.body(oauth2AccessTokenResponse())) .map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken())) - .flatMap(result -> ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) + .flatMap(result -> currentAuthentication() .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())) .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange)) .thenReturn(result)); } + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { if (this.authorizedClientRepository == null) { return false; diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 595ea57f9f..579bda1505 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -61,6 +61,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; /** @@ -263,6 +264,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request0)).isEmpty(); } + @Test + public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() { + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(clientRegistrationId(this.registration.getRegistrationId())) + .build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + private static String getBody(ClientRequest request) { final List> messageWriters = new ArrayList<>(); messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));