From cd7f6e09b08fb02ee0c6461b63daac9f9cca316f Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Mon, 23 Sep 2024 11:06:12 -0500 Subject: [PATCH] Look up ReactiveOAuth2AccessTokenResponseClient as a bean Closes gh-11097 --- .../config/web/server/ServerHttpSecurity.java | 13 ++- .../web/server/OAuth2ClientSpecTests.java | 107 +++++++++++++++++- 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index cd52e80738..6465e3dc9c 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -4813,11 +4813,22 @@ public class ServerHttpSecurity { private ReactiveAuthenticationManager getAuthenticationManager() { if (this.authenticationManager == null) { this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager( - new WebClientReactiveAuthorizationCodeTokenResponseClient()); + getAuthorizationCodeTokenResponseClient()); } return this.authenticationManager; } + private ReactiveOAuth2AccessTokenResponseClient getAuthorizationCodeTokenResponseClient() { + ResolvableType resolvableType = ResolvableType.forClassWithGenerics( + ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class); + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = getBeanOrNull( + resolvableType); + if (accessTokenResponseClient == null) { + accessTokenResponseClient = new WebClientReactiveAuthorizationCodeTokenResponseClient(); + } + return accessTokenResponseClient; + } + /** * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look * the value up as a Bean. diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java index d348d95f8a..0bd8391d71 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ package org.springframework.security.config.web.server; import java.net.URI; +import java.util.Set; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import reactor.core.publisher.Mono; import org.springframework.beans.factory.annotation.Autowired; @@ -31,9 +33,12 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; +import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; @@ -41,8 +46,10 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; @@ -59,7 +66,9 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; +import org.springframework.web.server.ServerWebExchange; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -215,6 +224,62 @@ public class OAuth2ClientSpecTests { verify(requestCache).getRedirectUri(any()); } + @Test + @SuppressWarnings("unchecked") + public void oauth2ClientWhenCustomAccessTokenResponseClientThenUsed() { + this.spring.register(OAuth2ClientBeanConfig.class, AuthorizedClientController.class).autowire(); + ReactiveClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() + .getBean(ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration)); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(ServerOAuth2AuthorizedClientRepository.class); + given(authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class), + any(Authentication.class), any(ServerWebExchange.class))) + .willReturn(Mono.empty()); + ServerAuthorizationRequestRepository authorizationRequestRepository = this.spring + .getContext() + .getBean(ServerAuthorizationRequestRepository.class); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + given(authorizationRequestRepository.loadAuthorizationRequest(any(ServerWebExchange.class))) + .willReturn(Mono.just(authorizationRequest)); + given(authorizationRequestRepository.removeAuthorizationRequest(any(ServerWebExchange.class))) + .willReturn(Mono.just(authorizationRequest)); + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = this.spring + .getContext() + .getBean(ReactiveOAuth2AccessTokenResponseClient.class); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("token") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .scopes(Set.of()) + .expiresIn(300) + .build(); + given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + this.client.get() + .uri((uriBuilder) -> uriBuilder + .path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build() + ) + .exchange() + .expectStatus().is3xxRedirection(); + // @formatter:on + ArgumentCaptor grantRequestArgumentCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizationCodeGrantRequest.class); + verify(accessTokenResponseClient).getTokenResponse(grantRequestArgumentCaptor.capture()); + OAuth2AuthorizationCodeGrantRequest grantRequest = grantRequestArgumentCaptor.getValue(); + assertThat(grantRequest.getClientRegistration()).isEqualTo(this.registration); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationRequest()).isEqualTo(authorizationRequest); + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()).isEqualTo("code"); + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getState()).isEqualTo("state"); + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getRedirectUri()) + .startsWith("/authorize/oauth2/code/registration-id"); + } + @Configuration @EnableWebFlux @EnableWebFluxSecurity @@ -324,4 +389,44 @@ public class OAuth2ClientSpecTests { } + @Configuration + @EnableWebFlux + @EnableWebFluxSecurity + static class OAuth2ClientBeanConfig { + + @Bean + SecurityWebFilterChain securityWebFilterChain(ServerHttpSecurity http) { + // @formatter:off + http + .oauth2Client((oauth2Client) -> oauth2Client + .authorizationRequestRepository(authorizationRequestRepository()) + ); + // @formatter:on + return http.build(); + } + + @Bean + @SuppressWarnings("unchecked") + ServerAuthorizationRequestRepository authorizationRequestRepository() { + return mock(ServerAuthorizationRequestRepository.class); + } + + @Bean + @SuppressWarnings("unchecked") + ReactiveOAuth2AccessTokenResponseClient authorizationCodeAccessTokenResponseClient() { + return mock(ReactiveOAuth2AccessTokenResponseClient.class); + } + + @Bean + ReactiveClientRegistrationRepository clientRegistrationRepository() { + return mock(ReactiveClientRegistrationRepository.class); + } + + @Bean + ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(ServerOAuth2AuthorizedClientRepository.class); + } + + } + }