Default HttpSessionRequestCache.matchingRequestParameterName=continue

Closes gh-11757
This commit is contained in:
Rob Winch 2022-08-25 12:09:15 -05:00
parent b28efbc4b8
commit f84f08c4b9
12 changed files with 117 additions and 36 deletions

View File

@ -34,7 +34,6 @@ import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -81,13 +80,8 @@ public class DeferHttpSessionJavaConfigTests {
DefaultSecurityFilterChain springSecurity(HttpSecurity http) throws Exception { DefaultSecurityFilterChain springSecurity(HttpSecurity http) throws Exception {
LazyCsrfTokenRepository csrfRepository = new LazyCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); LazyCsrfTokenRepository csrfRepository = new LazyCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
csrfRepository.setDeferLoadToken(true); csrfRepository.setDeferLoadToken(true);
HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
requestCache.setMatchingRequestParameterName("continue");
// @formatter:off // @formatter:off
http http
.requestCache((cache) -> cache
.requestCache(requestCache)
)
.securityContext((securityContext) -> securityContext .securityContext((securityContext) -> securityContext
.requireExplicitSave(true) .requireExplicitSave(true)
) )

View File

@ -23,6 +23,7 @@ import java.util.concurrent.Callable;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock; import org.mockito.Mock;
@ -54,6 +55,7 @@ import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.test.web.servlet.RequestCacheResultMatcher;
import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter; import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -192,8 +194,11 @@ public class HttpSecurityConfigurationTests {
public void authenticateWhenDefaultFilterChainBeanThenRedirectsToSavedRequest() throws Exception { public void authenticateWhenDefaultFilterChainBeanThenRedirectsToSavedRequest() throws Exception {
this.spring.register(SecurityEnabledConfig.class, UserDetailsConfig.class).autowire(); this.spring.register(SecurityEnabledConfig.class, UserDetailsConfig.class).autowire();
// @formatter:off // @formatter:off
MockHttpSession session = (MockHttpSession) this.mockMvc.perform(get("/messages")) MvcResult mvcResult = this.mockMvc.perform(get("/messages"))
.andReturn() .andReturn();
HttpServletRequest request = mvcResult.getRequest();
HttpServletResponse response = mvcResult.getResponse();
MockHttpSession session = (MockHttpSession) mvcResult
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
@ -203,10 +208,8 @@ public class HttpSecurityConfigurationTests {
.param("password", "password") .param("password", "password")
.session(session) .session(session)
.with(csrf()); .with(csrf());
// @formatter:on
// @formatter:off
this.mockMvc.perform(loginRequest) this.mockMvc.perform(loginRequest)
.andExpect(redirectedUrl("http://localhost/messages")); .andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
// @formatter:on // @formatter:on
} }

View File

@ -42,6 +42,8 @@ import org.springframework.security.web.authentication.session.SessionAuthentica
import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -187,9 +189,11 @@ public class CsrfConfigurerTests {
public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exception { public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exception {
this.spring.register(DisableCsrfEnablesRequestCacheConfig.class).autowire(); this.spring.register(DisableCsrfEnablesRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(post("/to-save")).andReturn(); MvcResult mvcResult = this.mvc.perform(post("/to-save")).andReturn();
RequestCache requestCache = new HttpSessionRequestCache();
String redirectUrl = requestCache.getRequest(mvcResult.getRequest(), mvcResult.getResponse()).getRedirectUrl();
this.mvc.perform(post("/login").param("username", "user").param("password", "password") this.mvc.perform(post("/login").param("username", "user").param("password", "password")
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/to-save")); .andExpect(redirectedUrl(redirectUrl));
} }
@Test @Test
@ -215,9 +219,11 @@ public class CsrfConfigurerTests {
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn(); MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn();
RequestCache requestCache = new HttpSessionRequestCache();
String redirectUrl = requestCache.getRequest(mvcResult.getRequest(), mvcResult.getResponse()).getRedirectUrl();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/some-url")); .andExpect(redirectedUrl(redirectUrl));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class)); .loadToken(any(HttpServletRequest.class));
} }

View File

@ -36,6 +36,7 @@ import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.test.web.servlet.RequestCacheResultMatcher;
import org.springframework.security.web.savedrequest.NullRequestCache; import org.springframework.security.web.savedrequest.NullRequestCache;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
@ -177,7 +178,7 @@ public class RequestCacheConfigurerTests {
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); this.mvc.perform(formLogin(session)).andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
} }
@Test @Test
@ -191,7 +192,7 @@ public class RequestCacheConfigurerTests {
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); this.mvc.perform(formLogin(session)).andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
} }
@Test @Test
@ -206,7 +207,7 @@ public class RequestCacheConfigurerTests {
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); this.mvc.perform(formLogin(session)).andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
} }
@Test @Test
@ -221,7 +222,7 @@ public class RequestCacheConfigurerTests {
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); this.mvc.perform(formLogin(session)).andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
} }
// gh-6102 // gh-6102
@ -275,7 +276,7 @@ public class RequestCacheConfigurerTests {
.getRequest() .getRequest()
.getSession(); .getSession();
// @formatter:on // @formatter:on
this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/bob")); this.mvc.perform(formLogin(session)).andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
} }
@Test @Test

View File

@ -35,6 +35,7 @@ import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.test.web.servlet.RequestCacheResultMatcher;
import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.test.web.support.WebTestUtils;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
@ -337,7 +338,7 @@ public class CsrfConfigTests {
.session(session) .session(session)
.with(csrf()); .with(csrf());
this.mvc.perform(login) this.mvc.perform(login)
.andExpect(redirectedUrl("http://localhost/authenticated")); .andExpect(RequestCacheResultMatcher.redirectToCachedRequest());
// @formatter:on // @formatter:on
} }

View File

@ -79,6 +79,7 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.test.web.servlet.RequestCacheResultMatcher;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.ExceptionTranslationFilter; import org.springframework.security.web.access.ExceptionTranslationFilter;
@ -807,7 +808,7 @@ public class MiscHttpConfigTests {
.session(session) .session(session)
.with(csrf()); .with(csrf());
session = (MockHttpSession) this.mvc.perform(loginRequest) session = (MockHttpSession) this.mvc.perform(loginRequest)
.andExpect(redirectedUrl("https://localhost:9443/protected")) .andExpect(RequestCacheResultMatcher.redirectToCachedRequest())
.andReturn() .andReturn()
.getRequest() .getRequest()
.getSession(false); .getSession(false);

View File

@ -0,0 +1,51 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.test.web.servlet;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.test.web.servlet.ResultMatcher;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Ensures that the MockMvcResult redirects to the saved ReqeuestCache.getRedirectUrl().
*/
public final class RequestCacheResultMatcher {
/**
* Verifies that the MockMvcResult redirects to the saved
* ReqeustCache.getRedirectUrl().
* @return a ResultMatcher that performs the verification.
*/
public static ResultMatcher redirectToCachedRequest() {
return (mvcResult) -> {
RequestCache requestCache = new HttpSessionRequestCache();
MockHttpServletResponse response = mvcResult.getResponse();
SavedRequest savedRequest = requestCache.getRequest(mvcResult.getRequest(), response);
assertThat(savedRequest).describedAs("savedReqeust cannot be null").isNotNull();
String cachedRedirectUrl = savedRequest.getRedirectUrl();
assertThat(response.getRedirectedUrl()).isEqualTo(cachedRedirectUrl);
};
}
private RequestCacheResultMatcher() {
}
}

View File

@ -57,6 +57,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -319,8 +320,9 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1); this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthenticationResult(this.registration1); this.setUpAuthenticationResult(this.registration1);
String redirectUrl = requestCache.getRequest(request, response).getRedirectUrl();
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); assertThat(response.getRedirectedUrl()).isEqualTo(redirectUrl);
} }
@Test @Test
@ -331,13 +333,17 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1); this.setUpAuthenticationResult(this.registration1);
RequestCache requestCache = spy(HttpSessionRequestCache.class); RequestCache requestCache = mock(RequestCache.class);
SavedRequest savedRequest = mock(SavedRequest.class);
String redirectUrl = "https://example.com/saved-request?success";
given(savedRequest.getRedirectUrl()).willReturn(redirectUrl);
given(requestCache.getRequest(any(), any())).willReturn(savedRequest);
this.filter.setRequestCache(requestCache); this.filter.setRequestCache(requestCache);
authorizationRequest.setRequestURI("/saved-request"); authorizationRequest.setRequestURI("/saved-request");
requestCache.saveRequest(authorizationRequest, response); requestCache.saveRequest(authorizationRequest, response);
this.filter.doFilter(authorizationResponse, response, filterChain); this.filter.doFilter(authorizationResponse, response, filterChain);
verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); assertThat(response.getRedirectedUrl()).isEqualTo(redirectUrl);
} }
@Test @Test

View File

@ -52,7 +52,7 @@ public class HttpSessionRequestCache implements RequestCache {
private String sessionAttrName = SAVED_REQUEST; private String sessionAttrName = SAVED_REQUEST;
private String matchingRequestParameterName; private String matchingRequestParameterName = "continue";
/** /**
* Stores the current request, provided the configuration properties allow it. * Stores the current request, provided the configuration properties allow it.
@ -177,7 +177,7 @@ public class HttpSessionRequestCache implements RequestCache {
* {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} * {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)}
* @param matchingRequestParameterName the parameter name that must be in the request * @param matchingRequestParameterName the parameter name that must be in the request
* for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check * for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check
* the session. * the session. Default is "continue".
*/ */
public void setMatchingRequestParameterName(String matchingRequestParameterName) { public void setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName; this.matchingRequestParameterName = matchingRequestParameterName;

View File

@ -44,6 +44,7 @@ import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.WebAttributes; import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest; import org.springframework.security.web.savedrequest.SavedRequest;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -104,7 +105,6 @@ public class ExceptionTranslationFilterTests {
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, fc); filter.doFilter(request, response, fc);
assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp"); assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp");
assertThat(getSavedRequestUrl(request)).isEqualTo("http://localhost/mycontext/secure/page.html");
} }
@Test @Test
@ -126,12 +126,13 @@ public class ExceptionTranslationFilterTests {
securityContext.setAuthentication( securityContext.setAuthentication(
new RememberMeAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("IGNORED"))); new RememberMeAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("IGNORED")));
SecurityContextHolder.setContext(securityContext); SecurityContextHolder.setContext(securityContext);
RequestCache requestCache = new HttpSessionRequestCache();
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint); ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint, requestCache);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, fc); filter.doFilter(request, response, fc);
assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp"); assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp");
assertThat(getSavedRequestUrl(request)).isEqualTo("http://localhost/mycontext/secure/page.html"); assertThat(getSavedRequestUrl(request)).isEqualTo(requestCache.getRequest(request, response).getRedirectUrl());
} }
@Test @Test
@ -199,12 +200,13 @@ public class ExceptionTranslationFilterTests {
willThrow(new BadCredentialsException("")).given(fc).doFilter(any(HttpServletRequest.class), willThrow(new BadCredentialsException("")).given(fc).doFilter(any(HttpServletRequest.class),
any(HttpServletResponse.class)); any(HttpServletResponse.class));
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint); RequestCache requestCache = new HttpSessionRequestCache();
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint, requestCache);
filter.afterPropertiesSet(); filter.afterPropertiesSet();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, fc); filter.doFilter(request, response, fc);
assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp"); assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp");
assertThat(getSavedRequestUrl(request)).isEqualTo("http://localhost/mycontext/secure/page.html"); assertThat(getSavedRequestUrl(request)).isEqualTo(requestCache.getRequest(request, response).getRedirectUrl());
} }
@Test @Test
@ -230,7 +232,6 @@ public class ExceptionTranslationFilterTests {
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, fc); filter.doFilter(request, response, fc);
assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp"); assertThat(response.getRedirectedUrl()).isEqualTo("/mycontext/login.jsp");
assertThat(getSavedRequestUrl(request)).isEqualTo("http://localhost:8080/mycontext/secure/page.html");
} }
@Test @Test

View File

@ -120,6 +120,21 @@ public class HttpSessionRequestCacheTests {
assertThat(matchingRequest).isNotNull(); assertThat(matchingRequest).isNotNull();
} }
@Test
public void getMatchingRequestWhenMatchesThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
HttpSessionRequestCache cache = new HttpSessionRequestCache();
cache.setMatchingRequestParameterName("success");
cache.saveRequest(request, new MockHttpServletResponse());
assertThat(request.getSession().getAttribute(HttpSessionRequestCache.SAVED_REQUEST)).isNotNull();
MockHttpServletRequest requestToMatch = new MockHttpServletRequest();
requestToMatch.setParameter("success", "");
requestToMatch.setSession(request.getSession());
HttpServletRequest matchingRequest = cache.getMatchingRequest(requestToMatch, new MockHttpServletResponse());
assertThat(matchingRequest).isNotNull();
assertThat(request.getSession().getAttribute(HttpSessionRequestCache.SAVED_REQUEST)).isNull();
}
private static final class CustomSavedRequest implements SavedRequest { private static final class CustomSavedRequest implements SavedRequest {
private final SavedRequest delegate; private final SavedRequest delegate;

View File

@ -26,19 +26,21 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
public class RequestCacheAwareFilterTests { public class RequestCacheAwareFilterTests {
@Test @Test
public void doFilterWhenHttpSessionRequestCacheConfiguredThenSavedRequestRemovedAfterMatch() throws Exception { public void doFilterWhenHttpSessionRequestCacheConfiguredThenSavedRequestRemovedAfterMatch() throws Exception {
RequestCacheAwareFilter filter = new RequestCacheAwareFilter();
HttpSessionRequestCache cache = new HttpSessionRequestCache();
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination"); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
cache.saveRequest(request, response); RequestCache requestCache = mock(RequestCache.class);
assertThat(request.getSession().getAttribute(HttpSessionRequestCache.SAVED_REQUEST)).isNotNull(); RequestCacheAwareFilter filter = new RequestCacheAwareFilter(requestCache);
given(requestCache.getMatchingRequest(request, response)).willReturn(request);
filter.doFilter(request, response, new MockFilterChain()); filter.doFilter(request, response, new MockFilterChain());
assertThat(request.getSession().getAttribute(HttpSessionRequestCache.SAVED_REQUEST)).isNull(); verify(requestCache).getMatchingRequest(request, response);
} }
@Test @Test