Use the custom ServerRequestCache that the user configures
on for the default authentication entry point and authentication success handler Fixes gh-7721 https://github.com/spring-projects/spring-security/issues/7721 Set RequestCache on the Oauth2LoginSpec default authentication success handler import static ReflectionTestUtils.getField Feedback incorporated per https://github.com/spring-projects/spring-security/pull/7734#pullrequestreview-332150359
This commit is contained in:
parent
65f5c29316
commit
9aa333ca4d
|
@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
|
||||||
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
|
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
|
||||||
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
|
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
|
||||||
|
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
|
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
||||||
|
@ -984,7 +986,7 @@ public class ServerHttpSecurity {
|
||||||
|
|
||||||
private ServerWebExchangeMatcher authenticationMatcher;
|
private ServerWebExchangeMatcher authenticationMatcher;
|
||||||
|
|
||||||
private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
|
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
|
||||||
|
|
||||||
private ServerAuthenticationFailureHandler authenticationFailureHandler;
|
private ServerAuthenticationFailureHandler authenticationFailureHandler;
|
||||||
|
|
||||||
|
@ -1175,7 +1177,7 @@ public class ServerHttpSecurity {
|
||||||
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
|
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
|
||||||
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
|
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
|
||||||
|
|
||||||
authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
|
authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
|
||||||
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
|
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
|
||||||
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
|
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
|
||||||
|
|
||||||
|
@ -1183,16 +1185,29 @@ public class ServerHttpSecurity {
|
||||||
MediaType.TEXT_HTML);
|
MediaType.TEXT_HTML);
|
||||||
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
|
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
|
||||||
Map<String, String> urlToText = http.oauth2Login.getLinks();
|
Map<String, String> urlToText = http.oauth2Login.getLinks();
|
||||||
|
String authenticationEntryPointRedirectPath;
|
||||||
if (urlToText.size() == 1) {
|
if (urlToText.size() == 1) {
|
||||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
|
authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
|
||||||
} else {
|
} else {
|
||||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
|
authenticationEntryPointRedirectPath = "/login";
|
||||||
}
|
}
|
||||||
|
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
|
||||||
|
entryPoint.setRequestCache(http.requestCache.requestCache);
|
||||||
|
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
|
||||||
|
|
||||||
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
|
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
|
||||||
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
|
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
|
||||||
|
if (this.authenticationSuccessHandler == null) {
|
||||||
|
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
|
||||||
|
handler.setRequestCache(http.requestCache.requestCache);
|
||||||
|
this.authenticationSuccessHandler = handler;
|
||||||
|
}
|
||||||
|
return this.authenticationSuccessHandler;
|
||||||
|
}
|
||||||
|
|
||||||
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
|
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
|
||||||
if (this.authenticationFailureHandler == null) {
|
if (this.authenticationFailureHandler == null) {
|
||||||
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");
|
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");
|
||||||
|
|
|
@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.mockito.BDDMockito.given;
|
import static org.mockito.BDDMockito.given;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
import static org.springframework.security.config.Customizer.withDefaults;
|
import static org.springframework.security.config.Customizer.withDefaults;
|
||||||
|
import static org.springframework.test.util.ReflectionTestUtils.getField;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.MockitoJUnitRunner;
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
||||||
|
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
||||||
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
||||||
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
||||||
|
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
|
||||||
|
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.test.publisher.TestPublisher;
|
import reactor.test.publisher.TestPublisher;
|
||||||
|
|
||||||
|
@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
|
||||||
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
||||||
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
||||||
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
||||||
import org.springframework.test.util.ReflectionTestUtils;
|
|
||||||
import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
||||||
import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
||||||
import org.springframework.test.web.reactive.server.WebTestClient;
|
import org.springframework.test.web.reactive.server.WebTestClient;
|
||||||
|
@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
|
||||||
.isNotPresent();
|
.isNotPresent();
|
||||||
|
|
||||||
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
||||||
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||||
|
|
||||||
assertThat(logoutHandler)
|
assertThat(logoutHandler)
|
||||||
.get()
|
.get()
|
||||||
|
@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
|
||||||
|
|
||||||
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
|
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
|
||||||
.get()
|
.get()
|
||||||
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
|
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
|
||||||
.isEqualTo(this.csrfTokenRepository);
|
.isEqualTo(this.csrfTokenRepository);
|
||||||
|
|
||||||
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
||||||
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
||||||
|
|
||||||
assertThat(logoutHandler)
|
assertThat(logoutHandler)
|
||||||
.get()
|
.get()
|
||||||
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
|
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
|
||||||
.extracting(delegatingLogoutHandler ->
|
.extracting(delegatingLogoutHandler ->
|
||||||
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
||||||
.map(ServerLogoutHandler::getClass)
|
.map(ServerLogoutHandler::getClass)
|
||||||
.collect(Collectors.toList()))
|
.collect(Collectors.toList()))
|
||||||
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
||||||
|
@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
|
||||||
verify(customServerCsrfTokenRepository).loadToken(any());
|
verify(customServerCsrfTokenRepository).loadToken(any());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
|
||||||
|
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
|
||||||
|
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
|
||||||
|
|
||||||
|
SecurityWebFilterChain securityFilterChain = this.http
|
||||||
|
.oauth2Login()
|
||||||
|
.clientRegistrationRepository(clientRegistrationRepository)
|
||||||
|
.and()
|
||||||
|
.authorizeExchange().anyExchange().authenticated()
|
||||||
|
.and()
|
||||||
|
.requestCache(c -> c.requestCache(requestCache))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
|
||||||
|
client.get().uri("/test").exchange();
|
||||||
|
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
|
||||||
|
verify(requestCache).saveRequest(captor.capture());
|
||||||
|
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
|
||||||
|
|
||||||
|
|
||||||
|
OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
|
||||||
|
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
|
||||||
|
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
|
||||||
|
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
|
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
|
||||||
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
|
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
|
||||||
|
@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
|
||||||
|
|
||||||
private boolean isX509Filter(WebFilter filter) {
|
private boolean isX509Filter(WebFilter filter) {
|
||||||
try {
|
try {
|
||||||
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
|
Object converter = getField(filter, "authenticationConverter");
|
||||||
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
|
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
// field doesn't exist
|
// field doesn't exist
|
||||||
|
|
Loading…
Reference in New Issue