diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 19711e54af..20eb5ab45d 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -1084,7 +1084,9 @@ public class ServerHttpSecurity { private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { if (this.authenticationConverter == null) { - this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + this.authenticationConverter = authenticationConverter; } return this.authenticationConverter; } diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index c95f8bd17d..c84e79e64a 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -39,6 +39,10 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; import reactor.core.publisher.Mono; @@ -475,6 +479,28 @@ public class ServerHttpSecurityTests { verify(customServerCsrfTokenRepository).loadToken(any()); } + @Test + public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { + ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); + + when(authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); + + SecurityWebFilterChain securityFilterChain = this.http + .oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository) + .authorizationRequestRepository(authorizationRequestRepository) + .and() + .build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/login/oauth2/code/registration-id").exchange(); + + verify(authorizationRequestRepository).removeAuthorizationRequest(any()); + } + private boolean isX509Filter(WebFilter filter) { try { Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");