Use the custom ServerRequestCache that the user configures
on for the default authentication entry point and authentication success handler Fixes gh-7721 https://github.com/spring-projects/spring-security/issues/7721 Set RequestCache on the Oauth2LoginSpec default authentication success handler import static ReflectionTestUtils.getField Feedback incorporated per https://github.com/spring-projects/spring-security/pull/7734#pullrequestreview-332150359
This commit is contained in:
parent
65f5c29316
commit
9aa333ca4d
|
@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
|
|||
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
|
||||
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
||||
|
@ -984,7 +986,7 @@ public class ServerHttpSecurity {
|
|||
|
||||
private ServerWebExchangeMatcher authenticationMatcher;
|
||||
|
||||
private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
|
||||
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
|
||||
|
||||
private ServerAuthenticationFailureHandler authenticationFailureHandler;
|
||||
|
||||
|
@ -1175,7 +1177,7 @@ public class ServerHttpSecurity {
|
|||
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
|
||||
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
|
||||
|
||||
authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
|
||||
authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
|
||||
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
|
||||
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
|
||||
|
||||
|
@ -1183,16 +1185,29 @@ public class ServerHttpSecurity {
|
|||
MediaType.TEXT_HTML);
|
||||
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
|
||||
Map<String, String> urlToText = http.oauth2Login.getLinks();
|
||||
String authenticationEntryPointRedirectPath;
|
||||
if (urlToText.size() == 1) {
|
||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
|
||||
authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
|
||||
} else {
|
||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
|
||||
authenticationEntryPointRedirectPath = "/login";
|
||||
}
|
||||
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
|
||||
entryPoint.setRequestCache(http.requestCache.requestCache);
|
||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
|
||||
|
||||
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
|
||||
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
|
||||
}
|
||||
|
||||
private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
|
||||
if (this.authenticationSuccessHandler == null) {
|
||||
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
|
||||
handler.setRequestCache(http.requestCache.requestCache);
|
||||
this.authenticationSuccessHandler = handler;
|
||||
}
|
||||
return this.authenticationSuccessHandler;
|
||||
}
|
||||
|
||||
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
|
||||
if (this.authenticationFailureHandler == null) {
|
||||
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");
|
||||
|
|
|
@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.config.Customizer.withDefaults;
|
||||
import static org.springframework.test.util.ReflectionTestUtils.getField;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
|
|||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
||||
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
||||
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
||||
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
|
||||
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.publisher.TestPublisher;
|
||||
|
||||
|
@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
|
|||
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
||||
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
||||
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
||||
import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
||||
import org.springframework.test.web.reactive.server.WebTestClient;
|
||||
|
@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
|
|||
.isNotPresent();
|
||||
|
||||
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
||||
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||
|
||||
assertThat(logoutHandler)
|
||||
.get()
|
||||
|
@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
|
|||
|
||||
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
|
||||
.get()
|
||||
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
|
||||
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
|
||||
.isEqualTo(this.csrfTokenRepository);
|
||||
|
||||
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
||||
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||
|
||||
assertThat(logoutHandler)
|
||||
.get()
|
||||
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
|
||||
.extracting(delegatingLogoutHandler ->
|
||||
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
||||
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
||||
.map(ServerLogoutHandler::getClass)
|
||||
.collect(Collectors.toList()))
|
||||
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
||||
|
@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
|
|||
verify(customServerCsrfTokenRepository).loadToken(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
|
||||
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
|
||||
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
|
||||
|
||||
SecurityWebFilterChain securityFilterChain = this.http
|
||||
.oauth2Login()
|
||||
.clientRegistrationRepository(clientRegistrationRepository)
|
||||
.and()
|
||||
.authorizeExchange().anyExchange().authenticated()
|
||||
.and()
|
||||
.requestCache(c -> c.requestCache(requestCache))
|
||||
.build();
|
||||
|
||||
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
|
||||
client.get().uri("/test").exchange();
|
||||
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
|
||||
verify(requestCache).saveRequest(captor.capture());
|
||||
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
|
||||
|
||||
|
||||
OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
|
||||
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
|
||||
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
|
||||
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
|
||||
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
|
||||
|
@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
|
|||
|
||||
private boolean isX509Filter(WebFilter filter) {
|
||||
try {
|
||||
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
|
||||
Object converter = getField(filter, "authenticationConverter");
|
||||
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
|
||||
} catch (IllegalArgumentException e) {
|
||||
// field doesn't exist
|
||||
|
|
Loading…
Reference in New Issue