Use the custom ServerRequestCache that the user configures

on for the default authentication entry point and authentication
success handler

Fixes gh-7721

https://github.com/spring-projects/spring-security/issues/7721

Set RequestCache on the Oauth2LoginSpec default authentication success handler

import static ReflectionTestUtils.getField

Feedback incorporated per

https://github.com/spring-projects/spring-security/pull/7734#pullrequestreview-332150359
This commit is contained in:
Filip Hanik 2019-12-12 11:54:46 -08:00
parent 65f5c29316
commit 9aa333ca4d
2 changed files with 57 additions and 10 deletions

View File

@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@ -984,7 +986,7 @@ public class ServerHttpSecurity {
private ServerWebExchangeMatcher authenticationMatcher; private ServerWebExchangeMatcher authenticationMatcher;
private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
private ServerAuthenticationFailureHandler authenticationFailureHandler; private ServerAuthenticationFailureHandler authenticationFailureHandler;
@ -1175,7 +1177,7 @@ public class ServerHttpSecurity {
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher()); authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
authenticationFilter.setSecurityContextRepository(this.securityContextRepository); authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
@ -1183,16 +1185,29 @@ public class ServerHttpSecurity {
MediaType.TEXT_HTML); MediaType.TEXT_HTML);
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
Map<String, String> urlToText = http.oauth2Login.getLinks(); Map<String, String> urlToText = http.oauth2Login.getLinks();
String authenticationEntryPointRedirectPath;
if (urlToText.size() == 1) { if (urlToText.size() == 1) {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next()))); authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
} else { } else {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login"))); authenticationEntryPointRedirectPath = "/login";
} }
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
entryPoint.setRequestCache(http.requestCache.requestCache);
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
} }
private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
if (this.authenticationSuccessHandler == null) {
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
handler.setRequestCache(http.requestCache.requestCache);
this.authenticationSuccessHandler = handler;
}
return this.authenticationSuccessHandler;
}
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() { private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
if (this.authenticationFailureHandler == null) { if (this.authenticationFailureHandler == null) {
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error"); this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");

View File

@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.util.ReflectionTestUtils.getField;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher; import reactor.test.publisher.TestPublisher;
@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
.isNotPresent(); .isNotPresent();
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler) assertThat(logoutHandler)
.get() .get()
@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
.get() .get()
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository")) .extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
.isEqualTo(this.csrfTokenRepository); .isEqualTo(this.csrfTokenRepository);
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler) assertThat(logoutHandler)
.get() .get()
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class) .isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler -> .extracting(delegatingLogoutHandler ->
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() ((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
.map(ServerLogoutHandler::getClass) .map(ServerLogoutHandler::getClass)
.collect(Collectors.toList())) .collect(Collectors.toList()))
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
verify(customServerCsrfTokenRepository).loadToken(any()); verify(customServerCsrfTokenRepository).loadToken(any());
} }
@Test
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
SecurityWebFilterChain securityFilterChain = this.http
.oauth2Login()
.clientRegistrationRepository(clientRegistrationRepository)
.and()
.authorizeExchange().anyExchange().authenticated()
.and()
.requestCache(c -> c.requestCache(requestCache))
.build();
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
client.get().uri("/test").exchange();
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(requestCache).saveRequest(captor.capture());
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
}
@Test @Test
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
private boolean isX509Filter(WebFilter filter) { private boolean isX509Filter(WebFilter filter) {
try { try {
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); Object converter = getField(filter, "authenticationConverter");
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class); return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
// field doesn't exist // field doesn't exist