Standardize Mock Request Paths

Closes gh-17449
This commit is contained in:
Josh Cummings 2025-07-02 18:16:41 -06:00
parent d869686d09
commit 98686a5139
No known key found for this signature in database
GPG Key ID: 869B37A20E876129
64 changed files with 399 additions and 721 deletions

View File

@ -14,6 +14,7 @@ dependencies {
provided 'jakarta.servlet:jakarta.servlet-api' provided 'jakarta.servlet:jakarta.servlet-api'
testImplementation project(path : ':spring-security-web', configuration : 'tests')
testImplementation "org.assertj:assertj-core" testImplementation "org.assertj:assertj-core"
testImplementation "org.junit.jupiter:junit-jupiter-api" testImplementation "org.junit.jupiter:junit-jupiter-api"
testImplementation "org.junit.jupiter:junit-jupiter-params" testImplementation "org.junit.jupiter:junit-jupiter-params"

View File

@ -55,6 +55,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests {@link CasAuthenticationFilter}. * Tests {@link CasAuthenticationFilter}.
@ -79,9 +81,7 @@ public class CasAuthenticationFilterTests {
@Test @Test
public void testNormalOperation() throws Exception { public void testNormalOperation() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login/cas"); MockHttpServletRequest request = post("/login/cas").param("ticket", "ST-0-ER94xMJmn6pha35CQRoZ").build();
request.setServletPath("/login/cas");
request.addParameter("ticket", "ST-0-ER94xMJmn6pha35CQRoZ");
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
filter.setAuthenticationManager((a) -> a); filter.setAuthenticationManager((a) -> a);
assertThat(filter.requiresAuthentication(request, new MockHttpServletResponse())).isTrue(); assertThat(filter.requiresAuthentication(request, new MockHttpServletResponse())).isTrue();
@ -104,24 +104,22 @@ public class CasAuthenticationFilterTests {
String url = "/login/cas"; String url = "/login/cas";
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
filter.setFilterProcessesUrl(url); filter.setFilterProcessesUrl(url);
MockHttpServletRequest request = new MockHttpServletRequest("POST", url); MockHttpServletRequest request = post(url).build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath(url);
assertThat(filter.requiresAuthentication(request, response)).isTrue(); assertThat(filter.requiresAuthentication(request, response)).isTrue();
} }
@Test @Test
public void testRequiresAuthenticationProxyRequest() { public void testRequiresAuthenticationProxyRequest() {
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/pgtCallback").build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath("/pgtCallback");
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
filter.setProxyReceptorUrl(request.getServletPath()); filter.setProxyReceptorUrl(request.getServletPath());
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
assertThat(filter.requiresAuthentication(request, response)).isTrue(); assertThat(filter.requiresAuthentication(request, response)).isTrue();
request.setServletPath("/other"); request = get("/other").build();
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
} }
@ -133,12 +131,10 @@ public class CasAuthenticationFilterTests {
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
filter.setFilterProcessesUrl(url); filter.setFilterProcessesUrl(url);
filter.setServiceProperties(properties); filter.setServiceProperties(properties);
MockHttpServletRequest request = new MockHttpServletRequest("POST", url); MockHttpServletRequest request = post(url).build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath(url);
assertThat(filter.requiresAuthentication(request, response)).isTrue(); assertThat(filter.requiresAuthentication(request, response)).isTrue();
request = new MockHttpServletRequest("POST", "/other"); request = post("/other").build();
request.setServletPath("/other");
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
request.setParameter(properties.getArtifactParameter(), "value"); request.setParameter(properties.getArtifactParameter(), "value");
assertThat(filter.requiresAuthentication(request, response)).isTrue(); assertThat(filter.requiresAuthentication(request, response)).isTrue();
@ -156,9 +152,8 @@ public class CasAuthenticationFilterTests {
@Test @Test
public void testAuthenticateProxyUrl() throws Exception { public void testAuthenticateProxyUrl() throws Exception {
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/pgtCallback").build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath("/pgtCallback");
filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
filter.setProxyReceptorUrl(request.getServletPath()); filter.setProxyReceptorUrl(request.getServletPath());
assertThat(filter.attemptAuthentication(request, response)).isNull(); assertThat(filter.attemptAuthentication(request, response)).isNull();
@ -172,9 +167,7 @@ public class CasAuthenticationFilterTests {
given(manager.authenticate(any(Authentication.class))).willReturn(authentication); given(manager.authenticate(any(Authentication.class))).willReturn(authentication);
ServiceProperties serviceProperties = new ServiceProperties(); ServiceProperties serviceProperties = new ServiceProperties();
serviceProperties.setAuthenticateAllArtifacts(true); serviceProperties.setAuthenticateAllArtifacts(true);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/authenticate"); MockHttpServletRequest request = post("/authenticate").param("ticket", "ST-1-123").build();
request.setParameter("ticket", "ST-1-123");
request.setServletPath("/authenticate");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
@ -200,10 +193,9 @@ public class CasAuthenticationFilterTests {
@Test @Test
public void testChainNotInvokedForProxyReceptor() throws Exception { public void testChainNotInvokedForProxyReceptor() throws Exception {
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/pgtCallback").build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
request.setServletPath("/pgtCallback");
filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
filter.setProxyReceptorUrl(request.getServletPath()); filter.setProxyReceptorUrl(request.getServletPath());
filter.doFilter(request, response, chain); filter.doFilter(request, response, chain);
@ -271,16 +263,14 @@ public class CasAuthenticationFilterTests {
@Test @Test
public void requiresAuthenticationWhenProxyRequestMatcherThenMatches() { public void requiresAuthenticationWhenProxyRequestMatcherThenMatches() {
CasAuthenticationFilter filter = new CasAuthenticationFilter(); CasAuthenticationFilter filter = new CasAuthenticationFilter();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/pgtCallback"); MockHttpServletRequest request = get("/pgtCallback").build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath("/pgtCallback");
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
filter.setProxyReceptorMatcher(PathPatternRequestMatcher.withDefaults().matcher(request.getServletPath())); filter.setProxyReceptorMatcher(PathPatternRequestMatcher.withDefaults().matcher(request.getServletPath()));
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
assertThat(filter.requiresAuthentication(request, response)).isTrue(); assertThat(filter.requiresAuthentication(request, response)).isTrue();
request.setRequestURI("/other"); request = get("/other").build();
request.setServletPath("/other");
assertThat(filter.requiresAuthentication(request, response)).isFalse(); assertThat(filter.requiresAuthentication(request, response)).isFalse();
} }

View File

@ -44,6 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link FilterChainProxy}. * Tests {@link FilterChainProxy}.
@ -143,13 +144,12 @@ public class FilterChainProxyConfigTests {
} }
private void doNormalOperation(FilterChainProxy filterChainProxy) throws Exception { private void doNormalOperation(FilterChainProxy filterChainProxy) throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletRequest request = get("/foo/secure/super/somefile.html").build();
request.setServletPath("/foo/secure/super/somefile.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
filterChainProxy.doFilter(request, response, chain); filterChainProxy.doFilter(request, response, chain);
verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
request.setServletPath("/a/path/which/doesnt/match/any/filter.html"); request = get("/a/path/which/doesnt/match/any/filter.html").build();
chain = mock(FilterChain.class); chain = mock(FilterChain.class);
filterChainProxy.doFilter(request, response, chain); filterChainProxy.doFilter(request, response, chain);
verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));

View File

@ -77,7 +77,6 @@ public class AuthorizeRequestsTests {
public void setup() { public void setup() {
this.servletContext = spy(MockServletContext.mvc()); this.servletContext = spy(MockServletContext.mvc());
this.request = new MockHttpServletRequest(this.servletContext, "GET", ""); this.request = new MockHttpServletRequest(this.servletContext, "GET", "");
this.request.setMethod("GET");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
} }
@ -111,10 +110,12 @@ public class AuthorizeRequestsTests {
public void antMatchersPathVariables() throws Exception { public void antMatchersPathVariables() throws Exception {
loadConfig(AntPatchersPathVariables.class); loadConfig(AntPatchersPathVariables.class);
this.request.setServletPath("/user/user"); this.request.setServletPath("/user/user");
this.request.setRequestURI("/user/user");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
this.setup(); this.setup();
this.request.setServletPath("/user/deny"); this.request.setServletPath("/user/deny");
this.request.setRequestURI("/user/deny");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
} }
@ -123,10 +124,12 @@ public class AuthorizeRequestsTests {
@Test @Test
public void antMatchersPathVariablesCaseInsensitive() throws Exception { public void antMatchersPathVariablesCaseInsensitive() throws Exception {
loadConfig(AntPatchersPathVariables.class); loadConfig(AntPatchersPathVariables.class);
this.request.setRequestURI("/USER/user");
this.request.setServletPath("/USER/user"); this.request.setServletPath("/USER/user");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
this.setup(); this.setup();
this.request.setRequestURI("/USER/deny");
this.request.setServletPath("/USER/deny"); this.request.setServletPath("/USER/deny");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
@ -137,10 +140,12 @@ public class AuthorizeRequestsTests {
public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception { public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception {
loadConfig(AntMatchersPathVariablesCamelCaseVariables.class); loadConfig(AntMatchersPathVariablesCamelCaseVariables.class);
this.request.setServletPath("/USER/user"); this.request.setServletPath("/USER/user");
this.request.setRequestURI("/USER/user");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
this.setup(); this.setup();
this.request.setServletPath("/USER/deny"); this.request.setServletPath("/USER/deny");
this.request.setRequestURI("/USER/deny");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
} }

View File

@ -39,6 +39,7 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* @author Rob Winch * @author Rob Winch
@ -48,8 +49,6 @@ public class HttpSecurityLogoutTests {
AnnotationConfigWebApplicationContext context; AnnotationConfigWebApplicationContext context;
MockHttpServletRequest request;
MockHttpServletResponse response; MockHttpServletResponse response;
MockFilterChain chain; MockFilterChain chain;
@ -59,7 +58,6 @@ public class HttpSecurityLogoutTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", "");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
} }
@ -77,11 +75,10 @@ public class HttpSecurityLogoutTests {
loadConfig(ClearAuthenticationFalseConfig.class); loadConfig(ClearAuthenticationFalseConfig.class);
SecurityContext currentContext = SecurityContextHolder.createEmptyContext(); SecurityContext currentContext = SecurityContextHolder.createEmptyContext();
currentContext.setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); currentContext.setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"));
this.request.getSession() MockHttpServletRequest request = post("/logout").build();
request.getSession()
.setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, currentContext); .setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, currentContext);
this.request.setMethod("POST"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.request.setServletPath("/logout");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(currentContext.getAuthentication()).isNotNull(); assertThat(currentContext.getAuthentication()).isNotNull();
} }

View File

@ -45,6 +45,7 @@ import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Rob Winch * @author Rob Winch
@ -54,8 +55,6 @@ public class HttpSecurityRequestMatchersTests {
AnnotationConfigWebApplicationContext context; AnnotationConfigWebApplicationContext context;
MockHttpServletRequest request;
MockHttpServletResponse response; MockHttpServletResponse response;
MockFilterChain chain; MockFilterChain chain;
@ -65,8 +64,6 @@ public class HttpSecurityRequestMatchersTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", "");
this.request.setMethod("GET");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
} }
@ -87,70 +84,64 @@ public class HttpSecurityRequestMatchersTests {
@Test @Test
public void requestMatchersMvcMatcherServletPath() throws Exception { public void requestMatchersMvcMatcherServletPath() throws Exception {
loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class); loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class);
this.request.setServletPath("/spring"); MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build();
this.request.setRequestURI("/spring/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setServletPath(""); request = get().requestUri(null, "", "/path").build();
this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
setup(); setup();
this.request.setServletPath("/other"); request = get().requestUri(null, "/other", "/path").build();
this.request.setRequestURI("/other/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
} }
@Test @Test
public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception { public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception {
loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class); loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class);
this.request.setServletPath("/spring"); MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build();
this.request.setRequestURI("/spring/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setServletPath(""); request = get().requestUri(null, "", "/path").build();
this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
setup(); setup();
this.request.setServletPath("/other"); request = get().requestUri(null, "/other", "/path").build();
this.request.setRequestURI("/other/path"); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
} }
@Test @Test
public void requestMatcherWhenMultiMvcMatcherInLambdaThenAllPathsAreDenied() throws Exception { public void requestMatcherWhenMultiMvcMatcherInLambdaThenAllPathsAreDenied() throws Exception {
loadConfig(MultiMvcMatcherInLambdaConfig.class); loadConfig(MultiMvcMatcherInLambdaConfig.class);
this.request.setRequestURI("/test-1"); MockHttpServletRequest request = get("/test-1").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setRequestURI("/test-2"); request = get("/test-2").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setRequestURI("/test-3"); request = get("/test-3").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
} }
@Test @Test
public void requestMatcherWhenMultiMvcMatcherThenAllPathsAreDenied() throws Exception { public void requestMatcherWhenMultiMvcMatcherThenAllPathsAreDenied() throws Exception {
loadConfig(MultiMvcMatcherConfig.class); loadConfig(MultiMvcMatcherConfig.class);
this.request.setRequestURI("/test-1"); MockHttpServletRequest request = get("/test-1").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setRequestURI("/test-2"); request = get("/test-2").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setRequestURI("/test-3"); request = get("/test-3").build();
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
} }

View File

@ -67,7 +67,7 @@ public class HttpSecuritySecurityMatchersNoMvcTests {
@BeforeEach @BeforeEach
public void setup() throws Exception { public void setup() throws Exception {
this.request = new MockHttpServletRequest("GET", ""); this.request = new MockHttpServletRequest();
this.request.setMethod("GET"); this.request.setMethod("GET");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
@ -83,15 +83,15 @@ public class HttpSecuritySecurityMatchersNoMvcTests {
@Test @Test
public void securityMatcherWhenNoMvcThenAntMatcher() throws Exception { public void securityMatcherWhenNoMvcThenAntMatcher() throws Exception {
loadConfig(SecurityMatcherNoMvcConfig.class); loadConfig(SecurityMatcherNoMvcConfig.class);
this.request.setServletPath("/path"); this.request.setRequestURI("/path");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED);
setup(); setup();
this.request.setServletPath("/path.html"); this.request.setRequestURI("/path.html");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
setup(); setup();
this.request.setServletPath("/path/"); this.request.setRequestURI("/path/");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
List<RequestMatcher> requestMatchers = this.springSecurityFilterChain.getFilterChains() List<RequestMatcher> requestMatchers = this.springSecurityFilterChain.getFilterChains()
.stream() .stream()

View File

@ -30,14 +30,10 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.DeferredCsrfToken;
@ -46,14 +42,13 @@ import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* @author Rob Winch * @author Rob Winch
*/ */
public class SessionManagementConfigurerServlet31Tests { public class SessionManagementConfigurerServlet31Tests {
MockHttpServletRequest request;
MockHttpServletResponse response; MockHttpServletResponse response;
MockFilterChain chain; MockFilterChain chain;
@ -64,7 +59,6 @@ public class SessionManagementConfigurerServlet31Tests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", "");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
} }
@ -78,13 +72,9 @@ public class SessionManagementConfigurerServlet31Tests {
@Test @Test
public void changeSessionIdThenPreserveParameters() throws Exception { public void changeSessionIdThenPreserveParameters() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletRequest request = post("/login").param("username", "user").param("password", "password").build();
String id = request.getSession().getId(); String id = request.getSession().getId();
request.getSession(); request.getSession();
request.setServletPath("/login");
request.setMethod("POST");
request.setParameter("username", "user");
request.setParameter("password", "password");
HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, this.response); DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, this.response);
@ -106,15 +96,6 @@ public class SessionManagementConfigurerServlet31Tests {
this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class);
} }
private void login(Authentication auth) {
HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response);
repo.loadContext(requestResponseHolder);
SecurityContextImpl securityContextImpl = new SecurityContextImpl();
securityContextImpl.setAuthentication(auth);
repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse());
}
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
static class SessionManagementDefaultSessionFixationServlet31Config { static class SessionManagementDefaultSessionFixationServlet31Config {

View File

@ -107,6 +107,7 @@ import org.springframework.security.web.authentication.HttpStatusEntryPoint;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.session.HttpSessionDestroyedEvent;
import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
@ -127,6 +128,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication; import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@ -185,8 +187,7 @@ public class OAuth2LoginConfigurerTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", ""); this.request = TestMockHttpServletRequests.get("/login/oauth2/code/google").build();
this.request.setServletPath("/login/oauth2/code/google");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.filterChain = new MockFilterChain(); this.filterChain = new MockFilterChain();
} }
@ -347,7 +348,7 @@ public class OAuth2LoginConfigurerTests {
loadConfig(OAuth2LoginConfigLoginProcessingUrl.class); loadConfig(OAuth2LoginConfigLoginProcessingUrl.class);
// setup authorization request // setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest();
this.request.setServletPath("/login/oauth2/google"); this.request.setRequestURI("/login/oauth2/google");
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response);
// setup authentication parameters // setup authentication parameters
this.request.setParameter("code", "code123"); this.request.setParameter("code", "code123");
@ -381,8 +382,7 @@ public class OAuth2LoginConfigurerTests {
// @formatter:on // @formatter:on
given(resolver.resolve(any())).willReturn(result); given(resolver.resolve(any())).willReturn(result);
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = TestMockHttpServletRequests.get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).isEqualTo( assertThat(this.response.getRedirectedUrl()).isEqualTo(
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
@ -394,8 +394,7 @@ public class OAuth2LoginConfigurerTests {
// @formatter:off // @formatter:off
// @formatter:on // @formatter:on
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = TestMockHttpServletRequests.get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).isEqualTo( assertThat(this.response.getRedirectedUrl()).isEqualTo(
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
@ -418,8 +417,7 @@ public class OAuth2LoginConfigurerTests {
// @formatter:on // @formatter:on
given(resolver.resolve(any())).willReturn(result); given(resolver.resolve(any())).willReturn(result);
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = TestMockHttpServletRequests.get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).isEqualTo( assertThat(this.response.getRedirectedUrl()).isEqualTo(
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
@ -432,8 +430,7 @@ public class OAuth2LoginConfigurerTests {
RedirectStrategy redirectStrategy = this.context RedirectStrategy redirectStrategy = this.context
.getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy; .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy;
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); then(redirectStrategy).should().sendRedirect(any(), any(), anyString());
} }
@ -445,8 +442,7 @@ public class OAuth2LoginConfigurerTests {
RedirectStrategy redirectStrategy = this.context RedirectStrategy redirectStrategy = this.context
.getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy; .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy;
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); then(redirectStrategy).should().sendRedirect(any(), any(), anyString());
} }
@ -456,8 +452,7 @@ public class OAuth2LoginConfigurerTests {
public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception {
loadConfig(OAuth2LoginConfig.class); loadConfig(OAuth2LoginConfig.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google");
} }
@ -467,8 +462,7 @@ public class OAuth2LoginConfigurerTests {
public void oauth2LoginWithOneClientConfiguredAndFormLoginThenRedirectDefaultLoginPage() throws Exception { public void oauth2LoginWithOneClientConfiguredAndFormLoginThenRedirectDefaultLoginPage() throws Exception {
loadConfig(OAuth2LoginConfigFormLogin.class); loadConfig(OAuth2LoginConfigFormLogin.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login");
} }
@ -479,8 +473,7 @@ public class OAuth2LoginConfigurerTests {
throws Exception { throws Exception {
loadConfig(OAuth2LoginConfig.class); loadConfig(OAuth2LoginConfig.class);
String requestUri = "/favicon.ico"; String requestUri = "/favicon.ico";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.request.addHeader(HttpHeaders.ACCEPT, new MediaType("image", "*").toString()); this.request.addHeader(HttpHeaders.ACCEPT, new MediaType("image", "*").toString());
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login");
@ -491,8 +484,7 @@ public class OAuth2LoginConfigurerTests {
public void oauth2LoginWithMultipleClientsConfiguredThenRedirectDefaultLoginPage() throws Exception { public void oauth2LoginWithMultipleClientsConfiguredThenRedirectDefaultLoginPage() throws Exception {
loadConfig(OAuth2LoginConfigMultipleClients.class); loadConfig(OAuth2LoginConfigMultipleClients.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login");
} }
@ -503,8 +495,7 @@ public class OAuth2LoginConfigurerTests {
throws Exception { throws Exception {
loadConfig(OAuth2LoginConfig.class); loadConfig(OAuth2LoginConfig.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.request.addHeader("X-Requested-With", "XMLHttpRequest");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).doesNotMatch("http://localhost/oauth2/authorization/google"); assertThat(this.response.getRedirectedUrl()).doesNotMatch("http://localhost/oauth2/authorization/google");
@ -515,8 +506,7 @@ public class OAuth2LoginConfigurerTests {
throws Exception { throws Exception {
loadConfig(OAuth2LoginWithHttpBasicConfig.class); loadConfig(OAuth2LoginWithHttpBasicConfig.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.request.addHeader("X-Requested-With", "XMLHttpRequest");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getStatus()).isEqualTo(401); assertThat(this.response.getStatus()).isEqualTo(401);
@ -527,8 +517,7 @@ public class OAuth2LoginConfigurerTests {
throws Exception { throws Exception {
loadConfig(OAuth2LoginWithXHREntryPointConfig.class); loadConfig(OAuth2LoginWithXHREntryPointConfig.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.request.addHeader("X-Requested-With", "XMLHttpRequest");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getStatus()).isEqualTo(401); assertThat(this.response.getStatus()).isEqualTo(401);
@ -540,8 +529,7 @@ public class OAuth2LoginConfigurerTests {
throws Exception { throws Exception {
loadConfig(OAuth2LoginConfigAuthorizationCodeClientAndOtherClients.class); loadConfig(OAuth2LoginConfigAuthorizationCodeClientAndOtherClients.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google");
} }
@ -550,8 +538,7 @@ public class OAuth2LoginConfigurerTests {
public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws Exception { public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws Exception {
loadConfig(OAuth2LoginConfigCustomLoginPage.class); loadConfig(OAuth2LoginConfigCustomLoginPage.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login");
} }
@ -560,8 +547,7 @@ public class OAuth2LoginConfigurerTests {
public void requestWhenOauth2LoginWithCustomLoginPageInLambdaThenRedirectCustomLoginPage() throws Exception { public void requestWhenOauth2LoginWithCustomLoginPageInLambdaThenRedirectCustomLoginPage() throws Exception {
loadConfig(OAuth2LoginConfigCustomLoginPageInLambda.class); loadConfig(OAuth2LoginConfigCustomLoginPageInLambda.class);
String requestUri = "/"; String requestUri = "/";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = get(requestUri).build();
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login");
} }

View File

@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}. * Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}.
@ -147,8 +148,7 @@ public class OidcUserRefreshedEventListenerConfigurationTests {
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
this.request = new MockHttpServletRequest("GET", ""); this.request = get("/").build();
this.request.setServletPath("/");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response));
} }

View File

@ -42,6 +42,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link OidcUserRefreshedEventListener}. * Tests for {@link OidcUserRefreshedEventListener}.
@ -64,8 +65,7 @@ public class OidcUserRefreshedEventListenerTests {
this.eventListener = new OidcUserRefreshedEventListener(); this.eventListener = new OidcUserRefreshedEventListener();
this.eventListener.setSecurityContextRepository(this.securityContextRepository); this.eventListener.setSecurityContextRepository(this.securityContextRepository);
this.request = new MockHttpServletRequest("GET", ""); this.request = get("/").build();
this.request.setServletPath("/");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
} }

View File

@ -94,6 +94,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
@ -190,8 +191,7 @@ public class Saml2LoginConfigurerTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("POST", ""); this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build();
this.request.setServletPath("/login/saml2/sso/test-rp");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.filterChain = new MockFilterChain(); this.filterChain = new MockFilterChain();
} }
@ -430,7 +430,6 @@ public class Saml2LoginConfigurerTests {
private void performSaml2Login(String expected) throws IOException, ServletException { private void performSaml2Login(String expected) throws IOException, ServletException {
// setup authentication parameters // setup authentication parameters
this.request.setRequestURI("/login/saml2/sso/registration-id"); this.request.setRequestURI("/login/saml2/sso/registration-id");
this.request.setServletPath("/login/saml2/sso/registration-id");
this.request.setParameter("SAMLResponse", this.request.setParameter("SAMLResponse",
Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes())); Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes()));
// perform test // perform test

View File

@ -76,6 +76,7 @@ import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.logout.LogoutFilter; import org.springframework.security.web.authentication.logout.LogoutFilter;
import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessHandler;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
@ -158,8 +159,7 @@ public class Saml2LogoutConfigurerTests {
Collections.emptyMap()); Collections.emptyMap());
principal.setRelyingPartyRegistrationId("registration-id"); principal.setRelyingPartyRegistrationId("registration-id");
this.user = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER")); this.user = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER"));
this.request = new MockHttpServletRequest("POST", ""); this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build();
this.request.setServletPath("/login/saml2/sso/test-rp");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
} }

View File

@ -132,9 +132,7 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests {
} }
private FilterInvocation createFilterInvocation(String path, String method) { private FilterInvocation createFilterInvocation(String path, String method) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletRequest request = new MockHttpServletRequest(method, path);
request.setRequestURI(path);
request.setMethod(method);
return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain()); return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain());
} }

View File

@ -134,8 +134,7 @@ public class Saml2LogoutBeanDefinitionParserTests {
principal.setRelyingPartyRegistrationId("registration-id"); principal.setRelyingPartyRegistrationId("registration-id");
this.saml2User = new Saml2Authentication(principal, "response", this.saml2User = new Saml2Authentication(principal, "response",
AuthorityUtils.createAuthorityList("ROLE_USER")); AuthorityUtils.createAuthorityList("ROLE_USER"));
this.request = new MockHttpServletRequest("POST", ""); this.request = new MockHttpServletRequest("POST", "/login/saml2/sso/test-rp");
this.request.setServletPath("/login/saml2/sso/test-rp");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
} }

View File

@ -26,10 +26,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.util.InMemoryXmlApplicationContext; import org.springframework.security.config.util.InMemoryXmlApplicationContext;
import org.springframework.security.core.Authentication; import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -61,7 +58,7 @@ public class SessionManagementConfigServlet31Tests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", ""); this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
} }
@ -75,12 +72,11 @@ public class SessionManagementConfigServlet31Tests {
@Test @Test
public void changeSessionIdThenPreserveParameters() throws Exception { public void changeSessionIdThenPreserveParameters() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletRequest request = TestMockHttpServletRequests.post("/login")
.param("username", "user")
.param("password", "password")
.build();
request.getSession(); request.getSession();
request.setServletPath("/login");
request.setMethod("POST");
request.setParameter("username", "user");
request.setParameter("password", "password");
request.getSession().setAttribute("attribute1", "value1"); request.getSession().setAttribute("attribute1", "value1");
String id = request.getSession().getId(); String id = request.getSession().getId();
// @formatter:off // @formatter:off
@ -99,12 +95,11 @@ public class SessionManagementConfigServlet31Tests {
@Test @Test
public void changeSessionId() throws Exception { public void changeSessionId() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletRequest request = TestMockHttpServletRequests.post("/login")
.param("username", "user")
.param("password", "password")
.build();
request.getSession(); request.getSession();
request.setServletPath("/login");
request.setMethod("POST");
request.setParameter("username", "user");
request.setParameter("password", "password");
String id = request.getSession().getId(); String id = request.getSession().getId();
// @formatter:off // @formatter:off
loadContext("<http>\n" loadContext("<http>\n"
@ -124,13 +119,4 @@ public class SessionManagementConfigServlet31Tests {
this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class);
} }
private void login(Authentication auth) {
HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response);
repo.loadContext(requestResponseHolder);
SecurityContextImpl securityContextImpl = new SecurityContextImpl();
securityContextImpl.setAuthentication(auth);
repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse());
}
} }

View File

@ -60,7 +60,7 @@ public class CustomHttpSecurityConfigurerTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.request = new MockHttpServletRequest("GET", ""); this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = new MockFilterChain(); this.chain = new MockFilterChain();
this.request.setMethod("GET"); this.request.setMethod("GET");
@ -76,7 +76,7 @@ public class CustomHttpSecurityConfigurerTests {
@Test @Test
public void customConfiguerPermitAll() throws Exception { public void customConfiguerPermitAll() throws Exception {
loadContext(Config.class); loadContext(Config.class);
this.request.setPathInfo("/public/something"); this.request.setRequestURI("/public/something");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
} }
@ -84,7 +84,7 @@ public class CustomHttpSecurityConfigurerTests {
@Test @Test
public void customConfiguerFormLogin() throws Exception { public void customConfiguerFormLogin() throws Exception {
loadContext(Config.class); loadContext(Config.class);
this.request.setPathInfo("/requires-authentication"); this.request.setRequestURI("/requires-authentication");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getRedirectedUrl()).endsWith("/custom"); assertThat(this.response.getRedirectedUrl()).endsWith("/custom");
} }
@ -92,7 +92,7 @@ public class CustomHttpSecurityConfigurerTests {
@Test @Test
public void customConfiguerCustomizeDisablesCsrf() throws Exception { public void customConfiguerCustomizeDisablesCsrf() throws Exception {
loadContext(ConfigCustomize.class); loadContext(ConfigCustomize.class);
this.request.setPathInfo("/public/something"); this.request.setRequestURI("/public/something");
this.request.setMethod("POST"); this.request.setMethod("POST");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
@ -101,7 +101,7 @@ public class CustomHttpSecurityConfigurerTests {
@Test @Test
public void customConfiguerCustomizeFormLogin() throws Exception { public void customConfiguerCustomizeFormLogin() throws Exception {
loadContext(ConfigCustomize.class); loadContext(ConfigCustomize.class);
this.request.setPathInfo("/requires-authentication"); this.request.setRequestURI("/requires-authentication");
this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain);
assertThat(this.response.getRedirectedUrl()).endsWith("/other"); assertThat(this.response.getRedirectedUrl()).endsWith("/other");
} }

View File

@ -41,6 +41,7 @@ import org.springframework.security.saml2.provider.service.web.authentication.Op
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link RelyingPartyRegistrationsBeanDefinitionParser}. * Tests for {@link RelyingPartyRegistrationsBeanDefinitionParser}.
@ -280,9 +281,7 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests {
Converter<HttpServletRequest, String> relayStateResolver = this.spring.getContext().getBean(Converter.class); Converter<HttpServletRequest, String> relayStateResolver = this.spring.getContext().getBean(Converter.class);
OpenSaml4AuthenticationRequestResolver authenticationRequestResolver = this.spring.getContext() OpenSaml4AuthenticationRequestResolver authenticationRequestResolver = this.spring.getContext()
.getBean(OpenSaml4AuthenticationRequestResolver.class); .getBean(OpenSaml4AuthenticationRequestResolver.class);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/saml2/authenticate/one").build();
request.setRequestURI("/saml2/authenticate/one");
request.setServletPath("/saml2/authenticate/one");
authenticationRequestResolver.resolve(request); authenticationRequestResolver.resolve(request);
verify(relayStateResolver).convert(request); verify(relayStateResolver).convert(request);
} }

View File

@ -44,8 +44,6 @@ import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import org.springframework.web.servlet.config.annotation.EnableWebMvc import org.springframework.web.servlet.config.annotation.EnableWebMvc
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer
/** /**
* Tests for [AuthorizeRequestsDsl] * Tests for [AuthorizeRequestsDsl]
@ -405,17 +403,11 @@ class AuthorizeRequestsDslTests {
this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.spring.register(MvcMatcherServletPathConfig::class.java).autowire()
this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path")
.with { request -> .servletPath("/spring"))
request.servletPath = "/spring"
request
})
.andExpect(status().isForbidden) .andExpect(status().isForbidden)
this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path")
.with { request -> .servletPath("/other"))
request.servletPath = "/other"
request
})
.andExpect(status().isOk) .andExpect(status().isOk)
} }
@ -514,28 +506,15 @@ class AuthorizeRequestsDslTests {
this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.spring.register(MvcMatcherServletPathConfig::class.java).autowire()
this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path")
.with { request -> .servletPath("/spring"))
request.apply {
servletPath = "/spring"
}
})
.andExpect(status().isForbidden) .andExpect(status().isForbidden)
this.mockMvc.perform(MockMvcRequestBuilders.put("/spring/path") this.mockMvc.perform(MockMvcRequestBuilders.put("/spring/path")
.with { request -> .servletPath("/spring"))
request.apply {
servletPath = "/spring"
csrf()
}
})
.andExpect(status().isForbidden) .andExpect(status().isForbidden)
this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path")
.with { request -> .servletPath("/other"))
request.apply {
servletPath = "/other"
}
})
.andExpect(status().isOk) .andExpect(status().isOk)
} }
} }

View File

@ -83,18 +83,12 @@ class RequiresChannelDslTests {
this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.spring.register(MvcMatcherServletPathConfig::class.java).autowire()
this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path")
.with { request -> .servletPath("/spring"))
request.servletPath = "/spring"
request
})
.andExpect(status().isFound) .andExpect(status().isFound)
.andExpect(redirectedUrl("https://localhost/spring/path")) .andExpect(redirectedUrl("https://localhost/spring/path"))
this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path")
.with { request -> .servletPath("/other"))
request.servletPath = "/other"
request
})
.andExpect(MockMvcResultMatchers.status().isOk) .andExpect(MockMvcResultMatchers.status().isOk)
} }

View File

@ -18,6 +18,7 @@
<property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.servlet.response.SecurityMockMvcResultHandlers.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.servlet.response.SecurityMockMvcResultHandlers.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.web.csrf.CsrfTokenAssert.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.web.csrf.CsrfTokenAssert.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.web.servlet.TestMockHttpServletRequests.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.AntPathRequestMatcher.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.AntPathRequestMatcher.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.RegexRequestMatcher.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.RegexRequestMatcher.*" />
<property name="avoidStaticImportExcludes" value="org.springframework.core.annotation.MergedAnnotations.SearchStrategy.*" /> <property name="avoidStaticImportExcludes" value="org.springframework.core.annotation.MergedAnnotations.SearchStrategy.*" />

View File

@ -9,7 +9,8 @@ dependencies {
implementation 'org.springframework:spring-context' implementation 'org.springframework:spring-context'
implementation 'org.springframework:spring-tx' implementation 'org.springframework:spring-tx'
testImplementation project(':spring-security-web') testImplementation project(path: ':spring-security-web')
testImplementation project(path: ':spring-security-web', configuration: 'tests')
testImplementation 'jakarta.servlet:jakarta.servlet-api' testImplementation 'jakarta.servlet:jakarta.servlet-api'
testImplementation 'org.springframework:spring-web' testImplementation 'org.springframework:spring-web'
testImplementation "org.assertj:assertj-core" testImplementation "org.assertj:assertj-core"

View File

@ -29,6 +29,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -43,9 +44,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests {
@Test @Test
public void requestThatIsMatchedByDefaultInterceptorIsAllowed() throws Exception { public void requestThatIsMatchedByDefaultInterceptorIsAllowed() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = TestMockHttpServletRequests.get("/somefile.html").build();
request.setMethod("GET");
request.setServletPath("/somefile.html");
request.setSession(createAuthenticatedSession("ROLE_0", "ROLE_1", "ROLE_2")); request.setSession(createAuthenticatedSession("ROLE_0", "ROLE_1", "ROLE_2"));
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.fcp.doFilter(request, response, new MockFilterChain()); this.fcp.doFilter(request, response, new MockFilterChain());
@ -54,10 +53,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests {
@Test @Test
public void securedUrlAccessIsRejectedWithoutRequiredRole() throws Exception { public void securedUrlAccessIsRejectedWithoutRequiredRole() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = TestMockHttpServletRequests.get("/secure/somefile.html").build();
request.setMethod("GET");
request.setServletPath("/secure/somefile.html");
request.setSession(createAuthenticatedSession("ROLE_0")); request.setSession(createAuthenticatedSession("ROLE_0"));
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.fcp.doFilter(request, response, new MockFilterChain()); this.fcp.doFilter(request, response, new MockFilterChain());

View File

@ -18,6 +18,7 @@ dependencies {
testImplementation project(path: ':spring-security-oauth2-core', configuration: 'tests') testImplementation project(path: ':spring-security-oauth2-core', configuration: 'tests')
testImplementation project(path: ':spring-security-oauth2-jose', configuration: 'tests') testImplementation project(path: ':spring-security-oauth2-jose', configuration: 'tests')
testImplementation project(path: ':spring-security-web', configuration: 'tests')
testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation 'com.squareup.okhttp3:mockwebserver'
testImplementation 'io.micrometer:context-propagation' testImplementation 'io.micrometer:context-propagation'
testImplementation 'io.projectreactor.netty:reactor-netty' testImplementation 'io.projectreactor.netty:reactor-netty'

View File

@ -44,6 +44,8 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.Assertions.entry;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}. * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@ -123,8 +125,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
@Test @Test
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() { public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNull(); assertThat(authorizationRequest).isNull();
} }
@ -133,7 +134,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
@Test @Test
public void resolveWhenNotAuthorizationRequestThenRequestBodyNotConsumed() throws IOException { public void resolveWhenNotAuthorizationRequestThenRequestBodyNotConsumed() throws IOException {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); MockHttpServletRequest request = post(requestUri).build();
request.setContent("foo".getBytes(StandardCharsets.UTF_8)); request.setContent("foo".getBytes(StandardCharsets.UTF_8));
request.setCharacterEncoding(StandardCharsets.UTF_8.name()); request.setCharacterEncoding(StandardCharsets.UTF_8.name());
HttpServletRequest spyRequest = Mockito.spy(request); HttpServletRequest spyRequest = Mockito.spy(request);
@ -151,8 +152,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId()
+ "-invalid"; + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
// @formatter:off // @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.resolver.resolve(request)) .isThrownBy(() -> this.resolver.resolve(request))
@ -164,8 +164,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestWithValidClientThenResolves() { public void resolveWhenAuthorizationRequestWithValidClientThenResolves() {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()) assertThat(authorizationRequest.getAuthorizationUri())
@ -191,8 +190,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() {
ClientRegistration clientRegistration = this.registration2; ClientRegistration clientRegistration = this.registration2;
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request,
clientRegistration.getRegistrationId()); clientRegistration.getRegistrationId());
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
@ -204,8 +202,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() {
ClientRegistration clientRegistration = this.registration2; ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -216,9 +213,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() {
ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("localhost:8080" + requestUri).build();
request.setServerPort(8080);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -229,10 +224,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() {
ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("https://localhost:8081" + requestUri).build();
request.setScheme("https");
request.setServerPort(8081);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -243,10 +235,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("http://localhost" + requestUri).build();
request.setScheme("http");
request.setServerPort(80);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -257,10 +246,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("https://localhost:443" + requestUri).build();
request.setScheme("https");
request.setServerPort(443);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -271,10 +257,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestHasNoPortThenInvalidUrlException() { public void resolveWhenAuthorizationRequestHasNoPortThenInvalidUrlException() {
ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).port(-1).build();
request.setScheme("https");
request.setServerPort(-1);
request.setServletPath(requestUri);
assertThatExceptionOfType(InvalidUrlException.class).isThrownBy(() -> this.resolver.resolve(request)); assertThatExceptionOfType(InvalidUrlException.class).isThrownBy(() -> this.resolver.resolve(request));
} }
@ -283,9 +266,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() {
ClientRegistration clientRegistration = this.registration2; ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri + "?foo=bar").build();
request.setServletPath(requestUri);
request.setQueryString("foo=bar");
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
@ -296,11 +277,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setScheme("http");
request.setServerName("localhost");
request.setServerPort(80);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&"
@ -312,11 +289,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("https://example.com:443" + requestUri).build();
request.setScheme("https");
request.setServerName("example.com");
request.setServerPort(443);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&"
@ -328,8 +301,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request,
clientRegistration.getRegistrationId()); clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
@ -342,8 +314,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() { public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() {
ClientRegistration clientRegistration = this.registration2; ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&"
@ -355,9 +326,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() { public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param("action", "authorize").build();
request.addParameter("action", "authorize");
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&"
@ -369,9 +338,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() { public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() {
ClientRegistration clientRegistration = this.registration2; ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param("action", "login").build();
request.addParameter("action", "login");
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&"
@ -383,8 +350,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestWithValidPublicClientThenResolves() { public void resolveWhenAuthorizationRequestWithValidPublicClientThenResolves() {
ClientRegistration clientRegistration = this.publicClientRegistration; ClientRegistration clientRegistration = this.publicClientRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()) assertThat(authorizationRequest.getAuthorizationUri())
@ -420,15 +386,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertPkceApplied(authorizationRequest, clientRegistration); assertPkceApplied(authorizationRequest, clientRegistration);
clientRegistration = this.registration2; clientRegistration = this.registration2;
requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
request = new MockHttpServletRequest("GET", requestUri); request = get(requestUri).build();
request.setServletPath(requestUri);
authorizationRequest = this.resolver.resolve(request); authorizationRequest = this.resolver.resolve(request);
assertPkceApplied(authorizationRequest, clientRegistration); assertPkceApplied(authorizationRequest, clientRegistration);
} }
@ -447,15 +411,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertPkceApplied(authorizationRequest, clientRegistration); assertPkceApplied(authorizationRequest, clientRegistration);
clientRegistration = this.registration2; clientRegistration = this.registration2;
requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
request = new MockHttpServletRequest("GET", requestUri); request = get(requestUri).build();
request.setServletPath(requestUri);
authorizationRequest = this.resolver.resolve(request); authorizationRequest = this.resolver.resolve(request);
assertPkceNotApplied(authorizationRequest, clientRegistration); assertPkceNotApplied(authorizationRequest, clientRegistration);
} }
@ -491,8 +453,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() { public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() {
ClientRegistration clientRegistration = this.oidcRegistration; ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()) assertThat(authorizationRequest.getAuthorizationUri())
@ -524,8 +485,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
ClientRegistration clientRegistration = this.oidcRegistration; ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer( this.resolver.setAuthorizationRequestCustomizer(
(builder) -> builder.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE)) (builder) -> builder.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE))
.attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE))); .attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE)));
@ -543,8 +503,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration; ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.authorizationRequestUri((uriBuilder) -> { this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.authorizationRequestUri((uriBuilder) -> {
uriBuilder.queryParam("param1", "value1"); uriBuilder.queryParam("param1", "value1");
return uriBuilder.build(); return uriBuilder.build();
@ -561,8 +520,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration; ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.parameters((params) -> { this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.parameters((params) -> {
params.put("appid", params.get("client_id")); params.put("appid", params.get("client_id"));
params.remove("client_id"); params.remove("client_id");
@ -579,8 +537,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver( OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository); this.clientRegistrationRepository);
String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()) assertThat(authorizationRequest.getRedirectUri())
.isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId()); .isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId());
@ -590,8 +547,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() {
ClientRegistration clientRegistration = this.pkceClientRegistration; ClientRegistration clientRegistration = this.pkceClientRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAdditionalParameters().containsKey(PkceParameterNames.CODE_CHALLENGE_METHOD)) assertThat(authorizationRequest.getAdditionalParameters().containsKey(PkceParameterNames.CODE_CHALLENGE_METHOD))
.isTrue(); .isTrue();

View File

@ -72,6 +72,7 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link OAuth2AuthorizationCodeGrantFilter}. * Tests for {@link OAuth2AuthorizationCodeGrantFilter}.
@ -154,8 +155,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@Test @Test
public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
// NOTE: A valid Authorization Response contains either a 'code' or 'error' // NOTE: A valid Authorization Response contains either a 'code' or 'error'
// parameter. // parameter.
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
@ -328,8 +328,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@Test @Test
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception { public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
String requestUri = "/saved-request"; String requestUri = "/saved-request";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RequestCache requestCache = new HttpSessionRequestCache(); RequestCache requestCache = new HttpSessionRequestCache();
requestCache.saveRequest(request, response); requestCache.saveRequest(request, response);
@ -430,8 +429,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
private static MockHttpServletRequest createAuthorizationRequest(String requestUri, private static MockHttpServletRequest createAuthorizationRequest(String requestUri,
Map<String, String> parameters) { Map<String, String> parameters) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
if (!CollectionUtils.isEmpty(parameters)) { if (!CollectionUtils.isEmpty(parameters)) {
parameters.forEach(request::addParameter); parameters.forEach(request::addParameter);
request.setQueryString(parameters.entrySet() request.setQueryString(parameters.entrySet()

View File

@ -55,6 +55,7 @@ import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}. * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
@ -127,8 +128,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
@Test @Test
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception { public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -139,8 +139,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception { public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId() + "-invalid"; + this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -154,8 +153,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
throws Exception { throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId() + "-invalid"; + this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> { this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
@ -178,8 +176,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception { public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId(); + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -193,8 +190,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenAuthorizationRequestOAuth2LoginThenAuthorizationRequestSaved() throws Exception { public void doFilterWhenAuthorizationRequestOAuth2LoginThenAuthorizationRequestSaved() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration2.getRegistrationId(); + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock( AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(
@ -212,8 +208,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository,
authorizationRequestBaseUri); authorizationRequestBaseUri);
String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId(); String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -227,8 +222,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization() public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization()
throws Exception { throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain)
@ -245,8 +239,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError()
throws Exception { throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain)
@ -266,8 +259,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
throws Exception { throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId(); + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
request.addParameter("idp", "https://other.provider.com"); request.addParameter("idp", "https://other.provider.com");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
@ -295,8 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
throws Exception { throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId(); + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
String loginHintParamName = "login_hint"; String loginHintParamName = "login_hint";
request.addParameter(loginHintParamName, "user@provider.com"); request.addParameter(loginHintParamName, "user@provider.com");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
@ -335,8 +326,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
throws Exception { throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId(); + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> { RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> {
@ -363,8 +353,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted() public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted()
throws Exception { throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
willAnswer((invocation) -> assertThat((invocation.<HttpServletResponse>getArgument(1)).isCommitted()).isFalse()) willAnswer((invocation) -> assertThat((invocation.<HttpServletResponse>getArgument(1)).isCommitted()).isFalse())

View File

@ -69,6 +69,7 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link OAuth2LoginAuthenticationFilter}. * Tests for {@link OAuth2LoginAuthenticationFilter}.
@ -163,8 +164,7 @@ public class OAuth2LoginAuthenticationFilterTests {
@Test @Test
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception { public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -176,8 +176,7 @@ public class OAuth2LoginAuthenticationFilterTests {
@Test @Test
public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception { public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).build();
request.setServletPath(requestUri);
// NOTE: // NOTE:
// A valid Authorization Response contains either a 'code' or 'error' parameter. // A valid Authorization Response contains either a 'code' or 'error' parameter.
// Don't set it to force an invalid Authorization Response. // Don't set it to force an invalid Authorization Response.
@ -198,10 +197,9 @@ public class OAuth2LoginAuthenticationFilterTests {
public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError()
throws Exception { throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, "state")
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -221,10 +219,9 @@ public class OAuth2LoginAuthenticationFilterTests {
throws Exception { throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
// @formatter:off // @formatter:off
@ -258,10 +255,9 @@ public class OAuth2LoginAuthenticationFilterTests {
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthorizationRequest(request, response, this.registration2, state);
@ -274,10 +270,9 @@ public class OAuth2LoginAuthenticationFilterTests {
public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception { public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthorizationRequest(request, response, this.registration1, state);
@ -300,10 +295,9 @@ public class OAuth2LoginAuthenticationFilterTests {
this.filter.setAuthenticationManager(this.authenticationManager); this.filter.setAuthenticationManager(this.authenticationManager);
String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthorizationRequest(request, response, this.registration2, state);
@ -319,13 +313,9 @@ public class OAuth2LoginAuthenticationFilterTests {
throws Exception { throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setScheme("http"); .param(OAuth2ParameterNames.STATE, state)
request.setServerName("localhost"); .build();
request.setServerPort(80);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthorizationRequest(request, response, this.registration2, state);
@ -350,13 +340,10 @@ public class OAuth2LoginAuthenticationFilterTests {
throws Exception { throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("https://example.com:443" + requestUri)
request.setScheme("https"); .param(OAuth2ParameterNames.CODE, "code")
request.setServerName("example.com"); .param(OAuth2ParameterNames.STATE, state)
request.setServerPort(443); .build();
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthorizationRequest(request, response, this.registration2, state);
@ -381,13 +368,10 @@ public class OAuth2LoginAuthenticationFilterTests {
throws Exception { throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get("https://example.com:9090" + requestUri)
request.setScheme("https"); .param(OAuth2ParameterNames.CODE, "code")
request.setServerName("example.com"); .param(OAuth2ParameterNames.STATE, state)
request.setServerPort(9090); .build();
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthorizationRequest(request, response, this.registration2, state);
@ -411,10 +395,9 @@ public class OAuth2LoginAuthenticationFilterTests {
public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult() throws Exception { public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class); WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class);
given(this.authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails); given(this.authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
@ -430,10 +413,9 @@ public class OAuth2LoginAuthenticationFilterTests {
this.filter.setAuthenticationResultConverter((authentication) -> null); this.filter.setAuthenticationResultConverter((authentication) -> null);
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1); this.setUpAuthenticationResult(this.registration1);
@ -448,10 +430,9 @@ public class OAuth2LoginAuthenticationFilterTests {
authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId())); authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state"; String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code")
request.setServletPath(requestUri); .param(OAuth2ParameterNames.STATE, state)
request.addParameter(OAuth2ParameterNames.CODE, "code"); .build();
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1); this.setUpAuthenticationResult(this.registration1);

View File

@ -108,6 +108,7 @@ dependencies {
optional 'com.fasterxml.jackson.core:jackson-databind' optional 'com.fasterxml.jackson.core:jackson-databind'
optional 'org.springframework:spring-jdbc' optional 'org.springframework:spring-jdbc'
testImplementation project(path: ':spring-security-web', configuration: 'tests')
testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation 'com.squareup.okhttp3:mockwebserver'
testImplementation "org.assertj:assertj-core" testImplementation "org.assertj:assertj-core"
testImplementation "org.skyscreamer:jsonassert" testImplementation "org.skyscreamer:jsonassert"

View File

@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.util.StreamUtils; import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -216,15 +217,11 @@ public final class OpenSaml4AuthenticationTokenConverterTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private <T extends SignableSAMLObject> T signed(T toSign) { private <T extends SignableSAMLObject> T signed(T toSign) {

View File

@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.util.StreamUtils; import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -216,15 +217,11 @@ public final class OpenSamlAuthenticationTokenConverterTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private <T extends SignableSAMLObject> T signed(T toSign) { private <T extends SignableSAMLObject> T signed(T toSign) {

View File

@ -28,6 +28,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -102,9 +103,7 @@ public class OpenSaml4AuthenticationRequestResolverTests {
} }
private MockHttpServletRequest givenRequest(String path) { private MockHttpServletRequest givenRequest(String path) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", path); return TestMockHttpServletRequests.get(path).build();
request.setServletPath(path);
return request;
} }
} }

View File

@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -135,15 +136,11 @@ public final class OpenSaml4LogoutRequestValidatorParametersResolverTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private String serialize(XMLObject object) { private String serialize(XMLObject object) {

View File

@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -135,15 +136,11 @@ public final class OpenSamlLogoutRequestValidatorParametersResolverTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private String serialize(XMLObject object) { private String serialize(XMLObject object) {

View File

@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.util.StreamUtils; import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -216,15 +217,11 @@ public final class OpenSaml5AuthenticationTokenConverterTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private <T extends SignableSAMLObject> T signed(T toSign) { private <T extends SignableSAMLObject> T signed(T toSign) {

View File

@ -28,6 +28,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -102,9 +103,7 @@ public class OpenSaml5AuthenticationRequestResolverTests {
} }
private MockHttpServletRequest givenRequest(String path) { private MockHttpServletRequest givenRequest(String path) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", path); return TestMockHttpServletRequests.get(path).build();
request.setServletPath(path);
return request;
} }
} }

View File

@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -135,15 +136,11 @@ public final class OpenSaml5LogoutRequestValidatorParametersResolverTests {
} }
private MockHttpServletRequest post(String uri) { private MockHttpServletRequest post(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); return TestMockHttpServletRequests.post(uri).build();
request.setServletPath(uri);
return request;
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private String serialize(XMLObject object) { private String serialize(XMLObject object) {

View File

@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.registration.InMemory
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.web.servlet.TestMockHttpServletRequests;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -121,9 +122,7 @@ public final class RequestMatcherMetadataResponseResolverTests {
} }
private MockHttpServletRequest get(String uri) { private MockHttpServletRequest get(String uri) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); return TestMockHttpServletRequests.get(uri).build();
request.setServletPath(uri);
return request;
} }
private RelyingPartyRegistration withEntityId(String entityId) { private RelyingPartyRegistration withEntityId(String entityId) {

View File

@ -46,6 +46,7 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests for {@link Saml2LogoutRequestFilter} * Tests for {@link Saml2LogoutRequestFilter}
@ -76,9 +77,8 @@ public class Saml2LogoutRequestFilterTests {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration);
given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success());
@ -105,9 +105,8 @@ public class Saml2LogoutRequestFilterTests {
given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication)); given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication));
this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration);
given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success());
@ -127,9 +126,7 @@ public class Saml2LogoutRequestFilterTests {
public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_RESPONSE, "response").build();
request.setServletPath("/logout");
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler);
@ -139,8 +136,7 @@ public class Saml2LogoutRequestFilterTests {
public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").build();
request.setServletPath("/logout/saml2/slo");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler);
@ -153,9 +149,8 @@ public class Saml2LogoutRequestFilterTests {
.build(); .build();
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration) Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration)
.samlResponse("response") .samlResponse("response")
@ -182,7 +177,6 @@ public class Saml2LogoutRequestFilterTests {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo");
request.setServletPath("/logout/saml2/slo");
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
@ -210,9 +204,8 @@ public class Saml2LogoutRequestFilterTests {
public void doFilterWhenInvalidBindingErrorLogoutResponseIsPosted() throws Exception { public void doFilterWhenInvalidBindingErrorLogoutResponseIsPosted() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
@ -242,9 +235,8 @@ public class Saml2LogoutRequestFilterTests {
public void doFilterWhenNoErrorResponseCanBeGeneratedThen401() throws Exception { public void doFilterWhenNoErrorResponseCanBeGeneratedThen401() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))

View File

@ -43,6 +43,8 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.verify; import static org.mockito.BDDMockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests for {@link Saml2LogoutResponseFilter} * Tests for {@link Saml2LogoutResponseFilter}
@ -74,9 +76,8 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenSamlResponsePostThenLogout() throws Exception { public void doFilterWhenSamlResponsePostThenLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration);
@ -94,8 +95,7 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenSamlResponseRedirectThenLogout() throws Exception { public void doFilterWhenSamlResponseRedirectThenLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/logout/saml2/slo"); MockHttpServletRequest request = get("/logout/saml2/slo").build();
request.setServletPath("/logout/saml2/slo");
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
@ -116,9 +116,7 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_REQUEST, "request").build();
request.setServletPath("/logout");
request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler);
@ -128,8 +126,7 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").build();
request.setServletPath("/logout/saml2/slo");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler);
@ -139,9 +136,8 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenValidatorFailsThenStops() throws Exception { public void doFilterWhenValidatorFailsThenStops() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration);
@ -160,9 +156,8 @@ public class Saml2LogoutResponseFilterTests {
public void doFilterWhenNoRelyingPartyLogoutThen401() throws Exception { public void doFilterWhenNoRelyingPartyLogoutThen401() throws Exception {
Authentication authentication = new TestingAuthenticationToken("user", "password"); Authentication authentication = new TestingAuthenticationToken("user", "password");
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response")
request.setServletPath("/logout/saml2/slo"); .build();
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
.singleLogoutServiceLocation(null) .singleLogoutServiceLocation(null)

View File

@ -39,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests for {@link Saml2RelyingPartyInitiatedLogoutSuccessHandler} * Tests for {@link Saml2RelyingPartyInitiatedLogoutSuccessHandler}
@ -72,8 +73,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests {
Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.samlRequest("request") .samlRequest("request")
.build(); .build();
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); MockHttpServletRequest request = post("/saml2/logout").build();
request.setServletPath("/saml2/logout");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest);
this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -92,8 +92,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests {
Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
.samlRequest("request") .samlRequest("request")
.build(); .build();
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); MockHttpServletRequest request = post("/saml2/logout").build();
request.setServletPath("/saml2/logout");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest);
this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication);

View File

@ -64,6 +64,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -96,8 +97,7 @@ public class FilterChainProxyTests {
}).given(this.filter).doFilter(any(), any(), any()); }).given(this.filter).doFilter(any(), any(), any());
this.fcp = new FilterChainProxy(new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter))); this.fcp = new FilterChainProxy(new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter)));
this.fcp.setFilterChainValidator(mock(FilterChainProxy.FilterChainValidator.class)); this.fcp.setFilterChainValidator(mock(FilterChainProxy.FilterChainValidator.class));
this.request = new MockHttpServletRequest("GET", ""); this.request = get("/path").build();
this.request.setServletPath("/path");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.chain = mock(FilterChain.class); this.chain = mock(FilterChain.class);
} }

View File

@ -34,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link FilterInvocation}. * Tests {@link FilterInvocation}.
@ -45,14 +46,8 @@ public class FilterInvocationTests {
@Test @Test
public void testGettersAndStringMethods() { public void testGettersAndStringMethods() {
MockHttpServletRequest request = new MockHttpServletRequest(null, null); MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", "/some/more/segments.html")
request.setServletPath("/HelloWorld"); .build();
request.setPathInfo("/some/more/segments.html");
request.setServerName("localhost");
request.setScheme("http");
request.setServerPort(80);
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/HelloWorld/some/more/segments.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
FilterInvocation fi = new FilterInvocation(request, response, chain); FilterInvocation fi = new FilterInvocation(request, response, chain);
@ -62,7 +57,7 @@ public class FilterInvocationTests {
assertThat(fi.getHttpResponse()).isEqualTo(response); assertThat(fi.getHttpResponse()).isEqualTo(response);
assertThat(fi.getChain()).isEqualTo(chain); assertThat(fi.getChain()).isEqualTo(chain);
assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld/some/more/segments.html"); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld/some/more/segments.html");
assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld/some/more/segments.html]"); assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld/some/more/segments.html]");
assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld/some/more/segments.html"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld/some/more/segments.html");
} }
@ -89,34 +84,23 @@ public class FilterInvocationTests {
@Test @Test
public void testStringMethodsWithAQueryString() { public void testStringMethodsWithAQueryString() {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null)
request.setQueryString("foo=bar"); .queryString("foo=bar")
request.setServletPath("/HelloWorld"); .build();
request.setServerName("localhost");
request.setScheme("http");
request.setServerPort(80);
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/HelloWorld");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld?foo=bar"); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld?foo=bar");
assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld?foo=bar]"); assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld?foo=bar]");
assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld?foo=bar"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld?foo=bar");
} }
@Test @Test
public void testStringMethodsWithoutAnyQueryString() { public void testStringMethodsWithoutAnyQueryString() {
MockHttpServletRequest request = new MockHttpServletRequest(null, null); MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null).build();
request.setServletPath("/HelloWorld");
request.setServerName("localhost");
request.setScheme("http");
request.setServerPort(80);
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/HelloWorld");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld"); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld");
assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld]"); assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld]");
assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld");
} }

View File

@ -29,6 +29,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link RequestMatcherRedirectFilter}. * Tests for {@link RequestMatcherRedirectFilter}.
@ -42,9 +43,7 @@ public class RequestMatcherRedirectFilterTests {
RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"), RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"),
"/test"); "/test");
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/context").build();
request.setServletPath("/context");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
@ -61,8 +60,7 @@ public class RequestMatcherRedirectFilterTests {
RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"), RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"),
"/test"); "/test");
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/test").build();
request.setServletPath("/test");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);

View File

@ -58,6 +58,7 @@ import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link ExceptionTranslationFilter}. * Tests {@link ExceptionTranslationFilter}.
@ -86,13 +87,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void testAccessDeniedWhenAnonymous() throws Exception { public void testAccessDeniedWhenAnonymous() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build();
request.setServletPath("/secure/page.html");
request.setServerPort(80);
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); FilterChain fc = mockFilterChainWithException(new AccessDeniedException(""));
// Setup SecurityContextHolder, as filter needs to check if user is // Setup SecurityContextHolder, as filter needs to check if user is
@ -129,13 +124,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void testAccessDeniedWithRememberMe() throws Exception { public void testAccessDeniedWithRememberMe() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build();
request.setServletPath("/secure/page.html");
request.setServerPort(80);
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); FilterChain fc = mockFilterChainWithException(new AccessDeniedException(""));
// Setup SecurityContextHolder, as filter needs to check if user is remembered // Setup SecurityContextHolder, as filter needs to check if user is remembered
@ -155,8 +144,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void testAccessDeniedWhenNonAnonymous() throws Exception { public void testAccessDeniedWhenNonAnonymous() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/secure/page.html").build();
request.setServletPath("/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); FilterChain fc = mockFilterChainWithException(new AccessDeniedException(""));
// Setup SecurityContextHolder, as filter needs to check if user is // Setup SecurityContextHolder, as filter needs to check if user is
@ -178,8 +166,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void testLocalizedErrorMessages() throws Exception { public void testLocalizedErrorMessages() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/secure/page.html").build();
request.setServletPath("/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); FilterChain fc = mockFilterChainWithException(new AccessDeniedException(""));
// Setup SecurityContextHolder, as filter needs to check if user is // Setup SecurityContextHolder, as filter needs to check if user is
@ -202,13 +189,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void redirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception { public void redirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build();
request.setServletPath("/secure/page.html");
request.setServerPort(80);
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an authentication failure exception // Setup the FilterChain to thrown an authentication failure exception
FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); FilterChain fc = mockFilterChainWithException(new BadCredentialsException(""));
// Test // Test
@ -225,13 +206,9 @@ public class ExceptionTranslationFilterTests {
public void redirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() public void redirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException()
throws Exception { throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://localhost:8080")
request.setServletPath("/secure/page.html"); .requestUri("/mycontext", "/secure/page.html", null)
request.setServerPort(8080); .build();
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/mycontext");
request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an authentication failure exception // Setup the FilterChain to thrown an authentication failure exception
FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); FilterChain fc = mockFilterChainWithException(new BadCredentialsException(""));
// Test // Test
@ -258,8 +235,7 @@ public class ExceptionTranslationFilterTests {
@Test @Test
public void successfulAccessGrant() throws Exception { public void successfulAccessGrant() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/secure/page.html").build();
request.setServletPath("/secure/page.html");
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint); ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint);
assertThat(filter.getAuthenticationEntryPoint()).isSameAs(this.mockEntryPoint); assertThat(filter.getAuthenticationEntryPoint()).isSameAs(this.mockEntryPoint);

View File

@ -32,6 +32,7 @@ import org.springframework.security.web.access.intercept.FilterInvocationSecurit
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link ChannelProcessingFilter}. * Tests {@link ChannelProcessingFilter}.
@ -81,9 +82,8 @@ public class ChannelProcessingFilterTests {
filter.setChannelDecisionManager(new MockChannelDecisionManager(true, "SOME_ATTRIBUTE")); filter.setChannelDecisionManager(new MockChannelDecisionManager(true, "SOME_ATTRIBUTE"));
MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE");
filter.setSecurityMetadataSource(fids); filter.setSecurityMetadataSource(fids);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/path").build();
request.setQueryString("info=now"); request.setQueryString("info=now");
request.setServletPath("/path");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, mock(FilterChain.class)); filter.doFilter(request, response, mock(FilterChain.class));
} }
@ -94,9 +94,8 @@ public class ChannelProcessingFilterTests {
filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "SOME_ATTRIBUTE")); filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "SOME_ATTRIBUTE"));
MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE");
filter.setSecurityMetadataSource(fids); filter.setSecurityMetadataSource(fids);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/path").build();
request.setQueryString("info=now"); request.setQueryString("info=now");
request.setServletPath("/path");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, mock(FilterChain.class)); filter.doFilter(request, response, mock(FilterChain.class));
} }
@ -107,9 +106,8 @@ public class ChannelProcessingFilterTests {
filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "NOT_USED")); filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "NOT_USED"));
MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "NOT_USED"); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "NOT_USED");
filter.setSecurityMetadataSource(fids); filter.setSecurityMetadataSource(fids);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE").build();
request.setQueryString("info=now"); request.setQueryString("info=now");
request.setServletPath("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, mock(FilterChain.class)); filter.doFilter(request, response, mock(FilterChain.class));
} }

View File

@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link InsecureChannelProcessor}. * Tests {@link InsecureChannelProcessor}.
@ -37,13 +38,9 @@ public class InsecureChannelProcessorTests {
@Test @Test
public void testDecideDetectsAcceptableChannel() throws Exception { public void testDecideDetectsAcceptableChannel() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null)
request.setQueryString("info=true"); .queryString("info=true")
request.setServerName("localhost"); .build();
request.setContextPath("/bigapp");
request.setServletPath("/servlet");
request.setScheme("http");
request.setServerPort(8080);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
InsecureChannelProcessor processor = new InsecureChannelProcessor(); InsecureChannelProcessor processor = new InsecureChannelProcessor();
@ -53,14 +50,9 @@ public class InsecureChannelProcessorTests {
@Test @Test
public void testDecideDetectsUnacceptableChannel() throws Exception { public void testDecideDetectsUnacceptableChannel() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null)
request.setQueryString("info=true"); .queryString("info=true")
request.setServerName("localhost"); .build();
request.setContextPath("/bigapp");
request.setServletPath("/servlet");
request.setScheme("https");
request.setSecure(true);
request.setServerPort(8443);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
InsecureChannelProcessor processor = new InsecureChannelProcessor(); InsecureChannelProcessor processor = new InsecureChannelProcessor();

View File

@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link SecureChannelProcessor}. * Tests {@link SecureChannelProcessor}.
@ -37,14 +38,9 @@ public class SecureChannelProcessorTests {
@Test @Test
public void testDecideDetectsAcceptableChannel() throws Exception { public void testDecideDetectsAcceptableChannel() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null)
request.setQueryString("info=true"); .queryString("info=true")
request.setServerName("localhost"); .build();
request.setContextPath("/bigapp");
request.setServletPath("/servlet");
request.setScheme("https");
request.setSecure(true);
request.setServerPort(8443);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
SecureChannelProcessor processor = new SecureChannelProcessor(); SecureChannelProcessor processor = new SecureChannelProcessor();
@ -54,13 +50,9 @@ public class SecureChannelProcessorTests {
@Test @Test
public void testDecideDetectsUnacceptableChannel() throws Exception { public void testDecideDetectsUnacceptableChannel() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null)
request.setQueryString("info=true"); .queryString("info=true")
request.setServerName("localhost"); .build();
request.setContextPath("/bigapp");
request.setServletPath("/servlet");
request.setScheme("http");
request.setServerPort(8080);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class));
SecureChannelProcessor processor = new SecureChannelProcessor(); SecureChannelProcessor processor = new SecureChannelProcessor();

View File

@ -31,6 +31,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.FilterInvocation;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Rob Winch * @author Rob Winch
@ -54,8 +55,7 @@ public class AbstractVariableEvaluationContextPostProcessorTests {
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.processor = new VariableEvaluationContextPostProcessor(); this.processor = new VariableEvaluationContextPostProcessor();
this.request = new MockHttpServletRequest(); this.request = get("/").build();
this.request.setServletPath("/");
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.invocation = new FilterInvocation(this.request, this.response, new MockFilterChain()); this.invocation = new FilterInvocation(this.request, this.response, new MockFilterChain());
this.context = new StandardEvaluationContext(); this.context = new StandardEvaluationContext();

View File

@ -32,6 +32,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.request;
/** /**
* Tests {@link DefaultFilterInvocationSecurityMetadataSource}. * Tests {@link DefaultFilterInvocationSecurityMetadataSource}.
@ -53,7 +54,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests {
@Test @Test
public void lookupNotRequiringExactMatchSucceedsIfNotMatching() { public void lookupNotRequiringExactMatchSucceedsIfNotMatching() {
createFids("/secure/super/**", null); createFids("/secure/super/**", null);
FilterInvocation fi = createFilterInvocation("/secure/super/somefile.html", null, null, null); FilterInvocation fi = createFilterInvocation("/secure/super/somefile.html", null, null, "GET");
assertThat(this.fids.getAttributes(fi)).isEqualTo(this.def); assertThat(this.fids.getAttributes(fi)).isEqualTo(this.def);
} }
@ -64,7 +65,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests {
@Test @Test
public void lookupNotRequiringExactMatchSucceedsIfSecureUrlPathContainsUpperCase() { public void lookupNotRequiringExactMatchSucceedsIfSecureUrlPathContainsUpperCase() {
createFids("/secure/super/**", null); createFids("/secure/super/**", null);
FilterInvocation fi = createFilterInvocation("/secure", "/super/somefile.html", null, null); FilterInvocation fi = createFilterInvocation("/secure", "/super/somefile.html", null, "GET");
Collection<ConfigAttribute> response = this.fids.getAttributes(fi); Collection<ConfigAttribute> response = this.fids.getAttributes(fi);
assertThat(response).isEqualTo(this.def); assertThat(response).isEqualTo(this.def);
} }
@ -72,7 +73,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests {
@Test @Test
public void lookupRequiringExactMatchIsSuccessful() { public void lookupRequiringExactMatchIsSuccessful() {
createFids("/SeCurE/super/**", null); createFids("/SeCurE/super/**", null);
FilterInvocation fi = createFilterInvocation("/SeCurE/super/somefile.html", null, null, null); FilterInvocation fi = createFilterInvocation("/SeCurE/super/somefile.html", null, null, "GET");
Collection<ConfigAttribute> response = this.fids.getAttributes(fi); Collection<ConfigAttribute> response = this.fids.getAttributes(fi);
assertThat(response).isEqualTo(this.def); assertThat(response).isEqualTo(this.def);
} }
@ -80,7 +81,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests {
@Test @Test
public void lookupRequiringExactMatchWithAdditionalSlashesIsSuccessful() { public void lookupRequiringExactMatchWithAdditionalSlashesIsSuccessful() {
createFids("/someAdminPage.html**", null); createFids("/someAdminPage.html**", null);
FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, "a=/test", null); FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, "a=/test", "GET");
Collection<ConfigAttribute> response = this.fids.getAttributes(fi); Collection<ConfigAttribute> response = this.fids.getAttributes(fi);
assertThat(response); // see SEC-161 (it should truncate after ? assertThat(response); // see SEC-161 (it should truncate after ?
// sign).isEqualTo(def) // sign).isEqualTo(def)
@ -129,22 +130,19 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests {
@Test @Test
public void extraQuestionMarkStillMatches() { public void extraQuestionMarkStillMatches() {
createFids("/someAdminPage.html*", null); createFids("/someAdminPage.html*", null);
FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, null, null); FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, null, "GET");
Collection<ConfigAttribute> response = this.fids.getAttributes(fi); Collection<ConfigAttribute> response = this.fids.getAttributes(fi);
assertThat(response).isEqualTo(this.def); assertThat(response).isEqualTo(this.def);
fi = createFilterInvocation("/someAdminPage.html", null, "?", null); fi = createFilterInvocation("/someAdminPage.html", null, "?", "GET");
response = this.fids.getAttributes(fi); response = this.fids.getAttributes(fi);
assertThat(response).isEqualTo(this.def); assertThat(response).isEqualTo(this.def);
} }
private FilterInvocation createFilterInvocation(String servletPath, String pathInfo, String queryString, private FilterInvocation createFilterInvocation(String servletPath, String pathInfo, String queryString,
String method) { String method) {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = request(method).requestUri(null, servletPath, pathInfo)
request.setRequestURI(null); .queryString(queryString)
request.setMethod(method); .build();
request.setServletPath(servletPath);
request.setPathInfo(pathInfo);
request.setQueryString(queryString);
return new FilterInvocation(request, new MockHttpServletResponse(), mock(FilterChain.class)); return new FilterInvocation(request, new MockHttpServletResponse(), mock(FilterChain.class));
} }

View File

@ -53,6 +53,7 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link FilterSecurityInterceptor}. * Tests {@link FilterSecurityInterceptor}.
@ -188,8 +189,7 @@ public class FilterSecurityInterceptorTests {
private FilterInvocation createinvocation() { private FilterInvocation createinvocation() {
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/secure/page.html").build();
request.setServletPath("/secure/page.html");
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
FilterInvocation fi = new FilterInvocation(request, response, chain); FilterInvocation fi = new FilterInvocation(request, response, chain);
return fi; return fi;

View File

@ -59,6 +59,9 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.Builder;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests {@link AbstractAuthenticationProcessingFilter}. * Tests {@link AbstractAuthenticationProcessingFilter}.
@ -75,13 +78,11 @@ public class AbstractAuthenticationProcessingFilterTests {
SimpleUrlAuthenticationFailureHandler failureHandler; SimpleUrlAuthenticationFailureHandler failureHandler;
private MockHttpServletRequest createMockAuthenticationRequest() { private MockHttpServletRequest createMockAuthenticationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest(); return withMockAuthenticationRequest().build();
request.setServletPath("/j_mock_post"); }
request.setScheme("http");
request.setServerName("www.example.com"); private Builder withMockAuthenticationRequest() {
request.setRequestURI("/mycontext/j_mock_post"); return get("www.example.com").requestUri("/mycontext", "/j_mock_post", null);
request.setContextPath("/mycontext");
return request;
} }
@BeforeEach @BeforeEach
@ -100,12 +101,11 @@ public class AbstractAuthenticationProcessingFilterTests {
@Test @Test
public void testDefaultProcessesFilterUrlMatchesWithPathParameter() { public void testDefaultProcessesFilterUrlMatchesWithPathParameter() {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login;jsessionid=I8MIONOSTHOR"); MockHttpServletRequest request = post("/login;jsessionid=I8MIONOSTHOR").build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
MockAuthenticationFilter filter = new MockAuthenticationFilter(); MockAuthenticationFilter filter = new MockAuthenticationFilter();
filter.setFilterProcessesUrl("/login"); filter.setFilterProcessesUrl("/login");
DefaultHttpFirewall firewall = new DefaultHttpFirewall(); DefaultHttpFirewall firewall = new DefaultHttpFirewall();
request.setServletPath("/login;jsessionid=I8MIONOSTHOR");
// the firewall ensures that path parameters are ignored // the firewall ensures that path parameters are ignored
HttpServletRequest firewallRequest = firewall.getFirewalledRequest(request); HttpServletRequest firewallRequest = firewall.getFirewalledRequest(request);
assertThat(filter.requiresAuthentication(firewallRequest, response)).isTrue(); assertThat(filter.requiresAuthentication(firewallRequest, response)).isTrue();
@ -114,9 +114,9 @@ public class AbstractAuthenticationProcessingFilterTests {
@Test @Test
public void testFilterProcessesUrlVariationsRespected() throws Exception { public void testFilterProcessesUrlVariationsRespected() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = createMockAuthenticationRequest(); MockHttpServletRequest request = withMockAuthenticationRequest()
request.setServletPath("/j_OTHER_LOCATION"); .requestUri("/mycontext", "/j_OTHER_LOCATION", null)
request.setRequestURI("/mycontext/j_OTHER_LOCATION"); .build();
// Setup our filter configuration // Setup our filter configuration
MockFilterConfig config = new MockFilterConfig(null, null); MockFilterConfig config = new MockFilterConfig(null, null);
// Setup our expectation that the filter chain will not be invoked, as we redirect // Setup our expectation that the filter chain will not be invoked, as we redirect
@ -150,9 +150,9 @@ public class AbstractAuthenticationProcessingFilterTests {
@Test @Test
public void testIgnoresAnyServletPathOtherThanFilterProcessesUrl() throws Exception { public void testIgnoresAnyServletPathOtherThanFilterProcessesUrl() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = createMockAuthenticationRequest(); MockHttpServletRequest request = withMockAuthenticationRequest()
request.setServletPath("/some.file.html"); .requestUri("/mycontext", "/some.file.html", null)
request.setRequestURI("/mycontext/some.file.html"); .build();
// Setup our filter configuration // Setup our filter configuration
MockFilterConfig config = new MockFilterConfig(null, null); MockFilterConfig config = new MockFilterConfig(null, null);
// Setup our expectation that the filter chain will be invoked, as our request is // Setup our expectation that the filter chain will be invoked, as our request is
@ -227,9 +227,9 @@ public class AbstractAuthenticationProcessingFilterTests {
@Test @Test
public void testNormalOperationWithRequestMatcherAndAuthenticationManager() throws Exception { public void testNormalOperationWithRequestMatcherAndAuthenticationManager() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = createMockAuthenticationRequest(); MockHttpServletRequest request = withMockAuthenticationRequest()
request.setServletPath("/j_eradicate_corona_virus"); .requestUri("/mycontext", "/j_eradicate_corona_virus", null)
request.setRequestURI("/mycontext/j_eradicate_corona_virus"); .build();
HttpSession sessionPreAuth = request.getSession(); HttpSession sessionPreAuth = request.getSession();
// Setup our filter configuration // Setup our filter configuration
MockFilterConfig config = new MockFilterConfig(null, null); MockFilterConfig config = new MockFilterConfig(null, null);

View File

@ -28,6 +28,7 @@ import org.springframework.security.web.PortMapperImpl;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link LoginUrlAuthenticationEntryPoint}. * Tests {@link LoginUrlAuthenticationEntryPoint}.
@ -73,12 +74,7 @@ public class LoginUrlAuthenticationEntryPointTests {
@Test @Test
public void testHttpsOperationFromOriginalHttpUrl() throws Exception { public void testHttpsOperationFromOriginalHttpUrl() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build();
request.setRequestURI("/some_path");
request.setScheme("http");
request.setServerName("www.example.com");
request.setContextPath("/bigWebApp");
request.setServerPort(80);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello");
ep.setPortMapper(new PortMapperImpl()); ep.setPortMapper(new PortMapperImpl());
@ -87,17 +83,17 @@ public class LoginUrlAuthenticationEntryPointTests {
ep.setPortResolver(new MockPortResolver(80, 443)); ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet(); ep.afterPropertiesSet();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/hello"); assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/hello");
request.setServerPort(8080); request.setServerPort(8080);
response = new MockHttpServletResponse(); response = new MockHttpServletResponse();
ep.setPortResolver(new MockPortResolver(8080, 8443)); ep.setPortResolver(new MockPortResolver(8080, 8443));
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello");
// Now test an unusual custom HTTP:HTTPS is handled properly // Now test an unusual custom HTTP:HTTPS is handled properly
request.setServerPort(8888); request.setServerPort(8888);
response = new MockHttpServletResponse(); response = new MockHttpServletResponse();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello");
PortMapperImpl portMapper = new PortMapperImpl(); PortMapperImpl portMapper = new PortMapperImpl();
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
map.put("8888", "9999"); map.put("8888", "9999");
@ -110,17 +106,13 @@ public class LoginUrlAuthenticationEntryPointTests {
ep.setPortResolver(new MockPortResolver(8888, 9999)); ep.setPortResolver(new MockPortResolver(8888, 9999));
ep.afterPropertiesSet(); ep.afterPropertiesSet();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:9999/bigWebApp/hello"); assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:9999/bigWebApp/hello");
} }
@Test @Test
public void testHttpsOperationFromOriginalHttpsUrl() throws Exception { public void testHttpsOperationFromOriginalHttpsUrl() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("https://www.example.com:443").requestUri("/bigWebApp", "/some_path", null)
request.setRequestURI("/some_path"); .build();
request.setScheme("https");
request.setServerName("www.example.com");
request.setContextPath("/bigWebApp");
request.setServerPort(443);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello");
ep.setPortMapper(new PortMapperImpl()); ep.setPortMapper(new PortMapperImpl());
@ -149,13 +141,7 @@ public class LoginUrlAuthenticationEntryPointTests {
ep.setPortMapper(new PortMapperImpl()); ep.setPortMapper(new PortMapperImpl());
ep.setPortResolver(new MockPortResolver(80, 443)); ep.setPortResolver(new MockPortResolver(80, 443));
ep.afterPropertiesSet(); ep.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build();
request.setRequestURI("/some_path");
request.setContextPath("/bigWebApp");
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/bigWebApp");
request.setServerPort(80);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/bigWebApp/hello"); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/bigWebApp/hello");
@ -167,13 +153,8 @@ public class LoginUrlAuthenticationEntryPointTests {
ep.setPortResolver(new MockPortResolver(8888, 1234)); ep.setPortResolver(new MockPortResolver(8888, 1234));
ep.setForceHttps(true); ep.setForceHttps(true);
ep.afterPropertiesSet(); ep.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://localhost:8888").requestUri("/bigWebApp", "/some_path", null)
request.setRequestURI("/some_path"); .build(); // NB: Port we can't resolve
request.setContextPath("/bigWebApp");
request.setScheme("http");
request.setServerName("localhost");
request.setContextPath("/bigWebApp");
request.setServerPort(8888); // NB: Port we can't resolve
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
ep.commence(request, response, null); ep.commence(request, response, null);
// Response doesn't switch to HTTPS, as we didn't know HTTP port 8888 to HTTP port // Response doesn't switch to HTTPS, as we didn't know HTTP port 8888 to HTTP port
@ -186,14 +167,7 @@ public class LoginUrlAuthenticationEntryPointTests {
LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello");
ep.setUseForward(true); ep.setUseForward(true);
ep.afterPropertiesSet(); ep.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build();
request.setRequestURI("/bigWebApp/some_path");
request.setServletPath("/some_path");
request.setContextPath("/bigWebApp");
request.setScheme("http");
request.setServerName("www.example.com");
request.setContextPath("/bigWebApp");
request.setServerPort(80);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getForwardedUrl()).isEqualTo("/hello"); assertThat(response.getForwardedUrl()).isEqualTo("/hello");
@ -205,17 +179,10 @@ public class LoginUrlAuthenticationEntryPointTests {
ep.setUseForward(true); ep.setUseForward(true);
ep.setForceHttps(true); ep.setForceHttps(true);
ep.afterPropertiesSet(); ep.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build();
request.setRequestURI("/bigWebApp/some_path");
request.setServletPath("/some_path");
request.setContextPath("/bigWebApp");
request.setScheme("http");
request.setServerName("www.example.com");
request.setContextPath("/bigWebApp");
request.setServerPort(80);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
ep.commence(request, response, null); ep.commence(request, response, null);
assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/some_path"); assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/some_path");
} }
// SEC-1498 // SEC-1498

View File

@ -28,6 +28,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link RequestMatcherDelegatingAuthenticationManagerResolverTests} * Tests for {@link RequestMatcherDelegatingAuthenticationManagerResolverTests}
@ -48,8 +49,7 @@ public class RequestMatcherDelegatingAuthenticationManagerResolverTests {
.add(new AntPathRequestMatcher("/two/**"), this.two) .add(new AntPathRequestMatcher("/two/**"), this.two)
.build(); .build();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/one/location"); MockHttpServletRequest request = get("/one/location").build();
request.setServletPath("/one/location");
assertThat(resolver.resolve(request)).isEqualTo(this.one); assertThat(resolver.resolve(request)).isEqualTo(this.one);
} }

View File

@ -39,6 +39,7 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests {@link UsernamePasswordAuthenticationFilter}. * Tests {@link UsernamePasswordAuthenticationFilter}.
@ -128,10 +129,10 @@ public class UsernamePasswordAuthenticationFilterTests {
@Test @Test
public void testSecurityContextHolderStrategyUsed() throws Exception { public void testSecurityContextHolderStrategyUsed() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login"); MockHttpServletRequest request = post("/login")
request.setServletPath("/login"); .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod")
request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod"); .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala")
request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala"); .build();
UsernamePasswordAuthenticationFilter filter = new UsernamePasswordAuthenticationFilter(); UsernamePasswordAuthenticationFilter filter = new UsernamePasswordAuthenticationFilter();
filter.setAuthenticationManager(createAuthenticationManager()); filter.setAuthenticationManager(createAuthenticationManager());
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());

View File

@ -24,6 +24,8 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.firewall.DefaultHttpFirewall; import org.springframework.security.web.firewall.DefaultHttpFirewall;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -39,22 +41,20 @@ public class LogoutHandlerTests {
@Test @Test
public void testRequiresLogoutUrlWorksWithPathParams() { public void testRequiresLogoutUrlWorksWithPathParams() {
MockHttpServletRequest request = new MockHttpServletRequest("POST", "/context/logout;someparam=blah"); MockHttpServletRequest request = post().requestUri("/context", "/logout;someparam=blah", null)
.queryString("otherparam=blah")
.build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setContextPath("/context");
request.setServletPath("/logout;someparam=blah");
request.setQueryString("otherparam=blah");
DefaultHttpFirewall fw = new DefaultHttpFirewall(); DefaultHttpFirewall fw = new DefaultHttpFirewall();
assertThat(this.filter.requiresLogout(fw.getFirewalledRequest(request), response)).isTrue(); assertThat(this.filter.requiresLogout(fw.getFirewalledRequest(request), response)).isTrue();
} }
@Test @Test
public void testRequiresLogoutUrlWorksWithQueryParams() { public void testRequiresLogoutUrlWorksWithQueryParams() {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/context/logout"); MockHttpServletRequest request = get().requestUri("/context", "/logout", null)
request.setContextPath("/context"); .queryString("otherparam=blah")
.build();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
request.setServletPath("/logout");
request.setQueryString("param=blah");
assertThat(this.filter.requiresLogout(request, response)).isTrue(); assertThat(this.filter.requiresLogout(request, response)).isTrue();
} }

View File

@ -38,6 +38,7 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post;
/** /**
* Tests for {@link GenerateOneTimeTokenWebFilter} * Tests for {@link GenerateOneTimeTokenWebFilter}
@ -55,7 +56,7 @@ public class GenerateOneTimeTokenFilterTests {
private static final String USERNAME = "user"; private static final String USERNAME = "user";
private final MockHttpServletRequest request = new MockHttpServletRequest(); private MockHttpServletRequest request;
private final MockHttpServletResponse response = new MockHttpServletResponse(); private final MockHttpServletResponse response = new MockHttpServletResponse();
@ -63,9 +64,7 @@ public class GenerateOneTimeTokenFilterTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.request.setMethod("POST"); this.request = post("/ott/generate").build();
this.request.setServletPath("/ott/generate");
this.request.setRequestURI("/ott/generate");
} }
@Test @Test
@ -87,6 +86,7 @@ public class GenerateOneTimeTokenFilterTests {
void filterWhenUsernameFormParamIsEmptyThenNull() throws ServletException, IOException { void filterWhenUsernameFormParamIsEmptyThenNull() throws ServletException, IOException {
given(this.oneTimeTokenService.generate(ArgumentMatchers.any(GenerateOneTimeTokenRequest.class))) given(this.oneTimeTokenService.generate(ArgumentMatchers.any(GenerateOneTimeTokenRequest.class)))
.willReturn((new DefaultOneTimeToken(TOKEN, USERNAME, Instant.now()))); .willReturn((new DefaultOneTimeToken(TOKEN, USERNAME, Instant.now())));
GenerateOneTimeTokenFilter filter = new GenerateOneTimeTokenFilter(this.oneTimeTokenService, GenerateOneTimeTokenFilter filter = new GenerateOneTimeTokenFilter(this.oneTimeTokenService,
this.successHandler); this.successHandler);

View File

@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests for {@link DefaultOneTimeTokenSubmitPageGeneratingFilter} * Tests for {@link DefaultOneTimeTokenSubmitPageGeneratingFilter}
@ -37,7 +38,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests {
DefaultOneTimeTokenSubmitPageGeneratingFilter filter = new DefaultOneTimeTokenSubmitPageGeneratingFilter(); DefaultOneTimeTokenSubmitPageGeneratingFilter filter = new DefaultOneTimeTokenSubmitPageGeneratingFilter();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/login/ott"); MockHttpServletRequest request;
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
@ -45,9 +46,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.request.setMethod("GET"); this.request = get("/login/ott").build();
this.request.setServletPath("/login/ott");
this.request.setRequestURI("/login/ott");
} }
@Test @Test
@ -80,10 +79,9 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests {
@Test @Test
void setContextThenGenerates() throws Exception { void setContextThenGenerates() throws Exception {
this.request.setContextPath("/context"); MockHttpServletRequest request = get().requestUri("/context", "/login/ott", null).build();
this.request.setRequestURI("/context/login/ott");
this.filter.setLoginProcessingUrl("/login/another"); this.filter.setLoginProcessingUrl("/login/another");
this.filter.doFilterInternal(this.request, this.response, this.filterChain); this.filter.doFilterInternal(request, this.response, this.filterChain);
String response = this.response.getContentAsString(); String response = this.response.getContentAsString();
assertThat(response).contains("<form class=\"login-form\" action=\"/context/login/another\" method=\"post\">"); assertThat(response).contains("<form class=\"login-form\" action=\"/context/login/another\" method=\"post\">");
} }
@ -101,7 +99,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests {
void filterThenRenders() throws Exception { void filterThenRenders() throws Exception {
this.request.setParameter("token", "this<>!@#\""); this.request.setParameter("token", "this<>!@#\"");
this.filter.setLoginProcessingUrl("/login/another"); this.filter.setLoginProcessingUrl("/login/another");
this.filter.setResolveHiddenInputs((request) -> Map.of("_csrf", "csrf-token-value")); this.filter.setResolveHiddenInputs((r) -> Map.of("_csrf", "csrf-token-value"));
this.filter.doFilterInternal(this.request, this.response, this.filterChain); this.filter.doFilterInternal(this.request, this.response, this.filterChain);
String response = this.response.getContentAsString(); String response = this.response.getContentAsString();
assertThat(response).isEqualTo( assertThat(response).isEqualTo(

View File

@ -61,6 +61,7 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link BasicAuthenticationFilter}. * Tests {@link BasicAuthenticationFilter}.
@ -94,8 +95,7 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception { public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.setServletPath("/some_file.html");
final MockHttpServletResponse response = new MockHttpServletResponse(); final MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
this.filter.doFilter(request, response, chain); this.filter.doFilter(request, response, chain);
@ -113,9 +113,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception { public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception {
String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
request.setSession(new MockHttpSession()); request.setSession(new MockHttpSession());
final MockHttpServletResponse response = new MockHttpServletResponse(); final MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -127,9 +126,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void invalidBase64IsIgnored() throws Exception { public void invalidBase64IsIgnored() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic NOT_VALID_BASE64"); request.addHeader("Authorization", "Basic NOT_VALID_BASE64");
request.setServletPath("/some_file.html");
request.setSession(new MockHttpSession()); request.setSession(new MockHttpSession());
final MockHttpServletResponse response = new MockHttpServletResponse(); final MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -143,9 +141,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testNormalOperation() throws Exception { public void testNormalOperation() throws Exception {
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -172,9 +169,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception { public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception {
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -187,9 +183,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void doFilterWhenSchemeMixedCaseThenCaseInsensitiveMatchWorks() throws Exception { public void doFilterWhenSchemeMixedCaseThenCaseInsensitiveMatchWorks() throws Exception {
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "BaSiC " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "BaSiC " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
this.filter.doFilter(request, new MockHttpServletResponse(), chain); this.filter.doFilter(request, new MockHttpServletResponse(), chain);
@ -200,9 +195,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testOtherAuthorizationSchemeIsIgnored() throws Exception { public void testOtherAuthorizationSchemeIsIgnored() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME"); request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME");
request.setServletPath("/some_file.html");
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
this.filter.doFilter(request, new MockHttpServletResponse(), chain); this.filter.doFilter(request, new MockHttpServletResponse(), chain);
verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
@ -222,9 +216,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception { public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception {
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
final MockHttpServletResponse response1 = new MockHttpServletResponse(); final MockHttpServletResponse response1 = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
this.filter.doFilter(request, response1, chain); this.filter.doFilter(request, response1, chain);
@ -240,7 +233,6 @@ public class BasicAuthenticationFilterTests {
chain = mock(FilterChain.class); chain = mock(FilterChain.class);
this.filter.doFilter(request, response2, chain); this.filter.doFilter(request, response2, chain);
verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
request.setServletPath("/some_file.html");
// Test - the filter chain will not be invoked, as we get a 401 forbidden response // Test - the filter chain will not be invoked, as we get a 401 forbidden response
MockHttpServletResponse response = response2; MockHttpServletResponse response = response2;
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
@ -250,9 +242,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception { public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception {
String token = "rod:WRONG_PASSWORD"; String token = "rod:WRONG_PASSWORD";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
request.setSession(new MockHttpSession()); request.setSession(new MockHttpSession());
this.filter = new BasicAuthenticationFilter(this.manager); this.filter = new BasicAuthenticationFilter(this.manager);
assertThat(this.filter.isIgnoreFailure()).isTrue(); assertThat(this.filter.isIgnoreFailure()).isTrue();
@ -266,9 +257,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception { public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception {
String token = "rod:WRONG_PASSWORD"; String token = "rod:WRONG_PASSWORD";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
request.setSession(new MockHttpSession()); request.setSession(new MockHttpSession());
assertThat(this.filter.isIgnoreFailure()).isFalse(); assertThat(this.filter.isIgnoreFailure()).isFalse();
final MockHttpServletResponse response = new MockHttpServletResponse(); final MockHttpServletResponse response = new MockHttpServletResponse();
@ -284,9 +274,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void skippedOnErrorDispatch() throws Exception { public void skippedOnErrorDispatch() throws Exception {
String token = "bad:credentials"; String token = "bad:credentials";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error"); request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -307,10 +296,9 @@ public class BasicAuthenticationFilterTests {
given(this.manager.authenticate(not(eq(rodRequest)))).willThrow(new BadCredentialsException("")); given(this.manager.authenticate(not(eq(rodRequest)))).willThrow(new BadCredentialsException(""));
this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint());
String token = "rod:äöü"; String token = "rod:äöü";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", request.addHeader("Authorization",
"Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8)));
request.setServletPath("/some_file.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
@ -336,10 +324,9 @@ public class BasicAuthenticationFilterTests {
this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint());
this.filter.setCredentialsCharset("ISO-8859-1"); this.filter.setCredentialsCharset("ISO-8859-1");
String token = "rod:äöü"; String token = "rod:äöü";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", request.addHeader("Authorization",
"Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.ISO_8859_1))); "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.ISO_8859_1)));
request.setServletPath("/some_file.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
@ -367,10 +354,9 @@ public class BasicAuthenticationFilterTests {
this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint());
this.filter.setCredentialsCharset("ISO-8859-1"); this.filter.setCredentialsCharset("ISO-8859-1");
String token = "rod:äöü"; String token = "rod:äöü";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", request.addHeader("Authorization",
"Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8)));
request.setServletPath("/some_file.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
@ -383,9 +369,8 @@ public class BasicAuthenticationFilterTests {
@Test @Test
public void requestWhenEmptyBasicAuthorizationHeaderTokenThenUnauthorized() throws Exception { public void requestWhenEmptyBasicAuthorizationHeaderTokenThenUnauthorized() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic "); request.addHeader("Authorization", "Basic ");
request.setServletPath("/some_file.html");
request.setSession(new MockHttpSession()); request.setSession(new MockHttpSession());
final MockHttpServletResponse response = new MockHttpServletResponse(); final MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain chain = mock(FilterChain.class); FilterChain chain = mock(FilterChain.class);
@ -401,9 +386,8 @@ public class BasicAuthenticationFilterTests {
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
this.filter.setSecurityContextRepository(securityContextRepository); this.filter.setSecurityContextRepository(securityContextRepository);
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/some_file.html").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/some_file.html");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
// Test // Test
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
@ -496,9 +480,8 @@ public class BasicAuthenticationFilterTests {
public void doFilterWhenCustomAuthenticationConverterThatIgnoresRequestThenIgnores() throws Exception { public void doFilterWhenCustomAuthenticationConverterThatIgnoresRequestThenIgnores() throws Exception {
this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); this.filter.setAuthenticationConverter(new TestAuthenticationConverter());
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/ignored").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/ignored");
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
@ -513,9 +496,8 @@ public class BasicAuthenticationFilterTests {
public void doFilterWhenCustomAuthenticationConverterRequestThenAuthenticate() throws Exception { public void doFilterWhenCustomAuthenticationConverterRequestThenAuthenticate() throws Exception {
this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); this.filter.setAuthenticationConverter(new TestAuthenticationConverter());
String token = "rod:koala"; String token = "rod:koala";
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get("/ok").build();
request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
request.setServletPath("/ok");
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);

View File

@ -53,6 +53,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* Tests {@link DigestAuthenticationFilter}. * Tests {@link DigestAuthenticationFilter}.
@ -131,8 +132,7 @@ public class DigestAuthenticationFilterTests {
this.filter = new DigestAuthenticationFilter(); this.filter = new DigestAuthenticationFilter();
this.filter.setUserDetailsService(uds); this.filter.setUserDetailsService(uds);
this.filter.setAuthenticationEntryPoint(ep); this.filter.setAuthenticationEntryPoint(ep);
this.request = new MockHttpServletRequest("GET", REQUEST_URI); this.request = get(REQUEST_URI).build();
this.request.setServletPath(REQUEST_URI);
} }
@Test @Test

View File

@ -41,6 +41,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Rob Winch * @author Rob Winch
@ -120,10 +121,7 @@ public class DebugFilterTests {
@Test @Test
public void doFilterLogsProperly() throws Exception { public void doFilterLogsProperly() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri(null, "/path", "/").build();
request.setMethod("GET");
request.setServletPath("/path");
request.setPathInfo("/");
request.addHeader("A", "A Value"); request.addHeader("A", "A Value");
request.addHeader("A", "Another Value"); request.addHeader("A", "Another Value");
request.addHeader("B", "B Value"); request.addHeader("B", "B Value");

View File

@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -34,8 +35,7 @@ public class DefaultHttpFirewallTests {
public void unnormalizedPathsAreRejected() { public void unnormalizedPathsAreRejected() {
DefaultHttpFirewall fw = new DefaultHttpFirewall(); DefaultHttpFirewall fw = new DefaultHttpFirewall();
for (String path : this.unnormalizedPaths) { for (String path : this.unnormalizedPaths) {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = get().requestUri(path).build();
request.setServletPath(path);
assertThatExceptionOfType(RequestRejectedException.class) assertThatExceptionOfType(RequestRejectedException.class)
.isThrownBy(() -> fw.getFirewalledRequest(request)); .isThrownBy(() -> fw.getFirewalledRequest(request));
request.setPathInfo(path); request.setPathInfo(path);

View File

@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
/** /**
* @author Rob Winch * @author Rob Winch
@ -112,8 +113,7 @@ public class StrictHttpFirewallTests {
@Test @Test
public void getFirewalledRequestWhenServletPathNotNormalizedThenThrowsRequestRejectedException() { public void getFirewalledRequestWhenServletPathNotNormalizedThenThrowsRequestRejectedException() {
for (String path : this.unnormalizedPaths) { for (String path : this.unnormalizedPaths) {
this.request = new MockHttpServletRequest("GET", ""); this.request = get().requestUri(path).build();
this.request.setServletPath(path);
assertThatExceptionOfType(RequestRejectedException.class) assertThatExceptionOfType(RequestRejectedException.class)
.isThrownBy(() -> this.firewall.getFirewalledRequest(this.request)); .isThrownBy(() -> this.firewall.getFirewalledRequest(this.request));
} }

View File

@ -28,6 +28,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get;
import static org.springframework.security.web.util.matcher.RegexRequestMatcher.regexMatcher; import static org.springframework.security.web.util.matcher.RegexRequestMatcher.regexMatcher;
/** /**
@ -50,8 +51,7 @@ public class RegexRequestMatcherTests {
@Test @Test
public void matchesIfHttpMethodAndPathMatch() { public void matchesIfHttpMethodAndPathMatch() {
RegexRequestMatcher matcher = new RegexRequestMatcher(".*", "GET"); RegexRequestMatcher matcher = new RegexRequestMatcher(".*", "GET");
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/anything"); MockHttpServletRequest request = get("/anything").build();
request.setServletPath("/anything");
assertThat(matcher.matches(request)).isTrue(); assertThat(matcher.matches(request)).isTrue();
} }