diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index a96da865cc..e395752cdd 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.oidc.user.OidcUser; @@ -984,7 +986,7 @@ public class ServerHttpSecurity { private ServerWebExchangeMatcher authenticationMatcher; - private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); + private ServerAuthenticationSuccessHandler authenticationSuccessHandler; private ServerAuthenticationFailureHandler authenticationFailureHandler; @@ -1175,7 +1177,7 @@ public class ServerHttpSecurity { authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher()); authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); - authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); + authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http)); authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); @@ -1183,16 +1185,29 @@ public class ServerHttpSecurity { MediaType.TEXT_HTML); htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); Map urlToText = http.oauth2Login.getLinks(); + String authenticationEntryPointRedirectPath; if (urlToText.size() == 1) { - http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next()))); + authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next(); } else { - http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login"))); + authenticationEntryPointRedirectPath = "/login"; } + RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath); + entryPoint.setRequestCache(http.requestCache.requestCache); + http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint)); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); } + private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) { + if (this.authenticationSuccessHandler == null) { + RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler(); + handler.setRequestCache(http.requestCache.requestCache); + this.authenticationSuccessHandler = handler; + } + return this.authenticationSuccessHandler; + } + private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() { if (this.authenticationFailureHandler == null) { this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error"); diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index c84e79e64a..052b6629a4 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.test.util.ReflectionTestUtils.getField; import java.util.Arrays; import java.util.List; @@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import reactor.core.publisher.Mono; import reactor.test.publisher.TestPublisher; @@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; -import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; @@ -200,7 +205,7 @@ public class ServerHttpSecurityTests { .isNotPresent(); Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) - .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); + .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); assertThat(logoutHandler) .get() @@ -213,17 +218,17 @@ public class ServerHttpSecurityTests { assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) .get() - .extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository")) + .extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository")) .isEqualTo(this.csrfTokenRepository); Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) - .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); + .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); assertThat(logoutHandler) .get() .isExactlyInstanceOf(DelegatingServerLogoutHandler.class) .extracting(delegatingLogoutHandler -> - ((List) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() + ((List) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() .map(ServerLogoutHandler::getClass) .collect(Collectors.toList())) .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); @@ -479,6 +484,33 @@ public class ServerHttpSecurityTests { verify(customServerCsrfTokenRepository).loadToken(any()); } + @Test + public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() { + ServerRequestCache requestCache = spy(new WebSessionServerRequestCache()); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + + SecurityWebFilterChain securityFilterChain = this.http + .oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository) + .and() + .authorizeExchange().anyExchange().authenticated() + .and() + .requestCache(c -> c.requestCache(requestCache)) + .build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/test").exchange(); + ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); + verify(requestCache).saveRequest(captor.capture()); + assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test"); + + + OAuth2LoginAuthenticationWebFilter authenticationWebFilter = + getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get(); + Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler"); + assertThat(getField(handler, "requestCache")).isSameAs(requestCache); + } + @Test public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); @@ -503,7 +535,7 @@ public class ServerHttpSecurityTests { private boolean isX509Filter(WebFilter filter) { try { - Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); + Object converter = getField(filter, "authenticationConverter"); return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class); } catch (IllegalArgumentException e) { // field doesn't exist