diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 1ccfdfd99c..f8f61f355d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -27,6 +27,7 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; @@ -86,6 +87,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + private boolean defaultOAuth2AuthorizedClient; + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); @@ -174,6 +177,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); } + /** + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is + * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be + * resolved from the current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. + * Default is false. + */ + public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { + this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + } + /** * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for * client_credentials grant. @@ -216,14 +230,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements if (this.authorizedClientRepository == null) { return Mono.empty(); } - String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME); - if (clientRegistrationId == null) { - return Mono.empty(); - } + ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); return currentAuthentication() - .flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal) - ); + .flatMap(principal -> clientRegistrationId(request, principal) + .flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal)) + ); + } + + private Mono clientRegistrationId(ClientRequest request, Authentication authentication) { + return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME)) + .cast(String.class) + .switchIfEmpty(clientRegistrationId(authentication)); + } + + private Mono clientRegistrationId(Authentication authentication) { + return Mono.justOrEmpty(authentication) + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); } private Mono loadAuthorizedClient(String clientRegistrationId, diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 630ae529e9..98bf156d07 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -34,8 +34,10 @@ import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -43,6 +45,8 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import reactor.core.publisher.Mono; @@ -51,6 +55,7 @@ import java.net.URI; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -60,6 +65,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; @@ -293,6 +299,59 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request0)).isEmpty(); } + @Test + public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() { + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); + this.function.setDefaultOAuth2AuthorizedClient(true); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections + .singletonMap("user", "rob"), "user"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); + this.function + .filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + + @Test + public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); + + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections + .singletonMap("user", "rob"), "user"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); + + this.function + .filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository); + } + private static String getBody(ClientRequest request) { final List> messageWriters = new ArrayList<>(); messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));