Use the custom ServerRequestCache that the user configures

on for the default authentication entry point and authentication
success handler

Fixes gh-7721

https://github.com/spring-projects/spring-security/issues/7721

Set RequestCache on the Oauth2LoginSpec default authentication success handler

import static ReflectionTestUtils.getField

Feedback incorporated per

https://github.com/spring-projects/spring-security/pull/7734#pullrequestreview-332150359
This commit is contained in:
Filip Hanik 2019-12-12 11:54:46 -08:00
parent 0d24e2b8cf
commit b754a3d635
2 changed files with 57 additions and 10 deletions

View File

@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@ -984,7 +986,7 @@ public class ServerHttpSecurity {
private ServerWebExchangeMatcher authenticationMatcher;
private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
private ServerAuthenticationFailureHandler authenticationFailureHandler;
@ -1175,7 +1177,7 @@ public class ServerHttpSecurity {
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
@ -1183,16 +1185,29 @@ public class ServerHttpSecurity {
MediaType.TEXT_HTML);
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
Map<String, String> urlToText = http.oauth2Login.getLinks();
String authenticationEntryPointRedirectPath;
if (urlToText.size() == 1) {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
} else {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
authenticationEntryPointRedirectPath = "/login";
}
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
entryPoint.setRequestCache(http.requestCache.requestCache);
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
}
private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
if (this.authenticationSuccessHandler == null) {
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
handler.setRequestCache(http.requestCache.requestCache);
this.authenticationSuccessHandler = handler;
}
return this.authenticationSuccessHandler;
}
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
if (this.authenticationFailureHandler == null) {
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");

View File

@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.util.ReflectionTestUtils.getField;
import java.util.Arrays;
import java.util.List;
@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;
@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
.isNotPresent();
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler)
.get()
@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
.get()
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
.isEqualTo(this.csrfTokenRepository);
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler)
.get()
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler ->
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
.map(ServerLogoutHandler::getClass)
.collect(Collectors.toList()))
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
verify(customServerCsrfTokenRepository).loadToken(any());
}
@Test
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
SecurityWebFilterChain securityFilterChain = this.http
.oauth2Login()
.clientRegistrationRepository(clientRegistrationRepository)
.and()
.authorizeExchange().anyExchange().authenticated()
.and()
.requestCache(c -> c.requestCache(requestCache))
.build();
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
client.get().uri("/test").exchange();
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(requestCache).saveRequest(captor.capture());
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
}
@Test
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
private boolean isX509Filter(WebFilter filter) {
try {
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
Object converter = getField(filter, "authenticationConverter");
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
} catch (IllegalArgumentException e) {
// field doesn't exist