Add postProcess support to Saml2LogoutConfigurer

Closes gh-10311
This commit is contained in:
Gaurav Tiwari 2021-10-02 13:23:09 +05:30 committed by Josh Cummings
parent fbb7691be4
commit 33708e61fb
2 changed files with 115 additions and 3 deletions

View File

@ -253,7 +253,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
Saml2LogoutRequestFilter filter = new Saml2LogoutRequestFilter(registrations,
this.logoutRequestConfigurer.logoutRequestValidator(), logoutResponseResolver, logoutHandlers);
filter.setLogoutRequestMatcher(createLogoutRequestMatcher());
return filter;
return postProcess(filter);
}
private Saml2LogoutResponseFilter createLogoutResponseProcessingFilter(
@ -262,7 +262,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
this.logoutResponseConfigurer.logoutResponseValidator(), this.logoutSuccessHandler);
logoutResponseFilter.setLogoutRequestMatcher(createLogoutResponseMatcher());
logoutResponseFilter.setLogoutRequestRepository(this.logoutRequestConfigurer.logoutRequestRepository);
return logoutResponseFilter;
return postProcess(logoutResponseFilter);
}
private LogoutFilter createRelyingPartyLogoutFilter(RelyingPartyRegistrationResolver registrations) {
@ -271,7 +271,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
registrations);
LogoutFilter logoutFilter = new LogoutFilter(logoutRequestSuccessHandler, logoutHandlers);
logoutFilter.setLogoutRequestMatcher(createLogoutMatcher());
return logoutFilter;
return postProcess(logoutFilter);
}
private RequestMatcher createLogoutMatcher() {

View File

@ -37,6 +37,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.test.SpringTestContext;
@ -59,12 +60,16 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseFilter;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseResolver;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.logout.LogoutFilter;
import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.security.web.authentication.logout.LogoutSuccessHandler;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
@ -75,6 +80,8 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoInteractions;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@ -346,6 +353,47 @@ public class Saml2LogoutConfigurerTests {
verify(getBean(Saml2LogoutResponseValidator.class)).validate(any());
}
@Test
public void saml2LogoutWhenLogoutGetThenLogsOutAndSendsLogoutRequest() throws Exception {
this.spring.register(Saml2LogoutWithHttpGet.class).autowire();
MvcResult result = this.mvc.perform(get("/logout").with(authentication(this.user)))
.andExpect(status().isFound()).andReturn();
String location = result.getResponse().getHeader("Location");
LogoutHandler logoutHandler = this.spring.getContext().getBean(LogoutHandler.class);
assertThat(location).startsWith("https://ap.example.org/logout/saml2/request");
verify(logoutHandler).logout(any(), any(), any());
}
@Test
public void saml2LogoutWhenSaml2LogoutRequestFilterPostProcessedThenUses() {
Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor)
.postProcess(any(Saml2LogoutRequestFilter.class));
}
@Test
public void saml2LogoutWhenSaml2LogoutResponseFilterPostProcessedThenUses() {
Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor)
.postProcess(any(Saml2LogoutResponseFilter.class));
}
@Test
public void saml2LogoutWhenLogoutFilterPostProcessedThenUses() {
Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor, atLeastOnce())
.postProcess(any(LogoutFilter.class));
}
private <T> T getBean(Class<T> clazz) {
return this.spring.getContext().getBean(clazz);
}
@ -401,6 +449,61 @@ public class Saml2LogoutConfigurerTests {
}
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class Saml2LogoutWithHttpGet {
LogoutHandler mockLogoutHandler = mock(LogoutHandler.class);
@Bean
SecurityFilterChain web(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authorize) -> authorize.anyRequest().authenticated())
.logout((logout) -> logout.addLogoutHandler(this.mockLogoutHandler))
.saml2Login(withDefaults())
.saml2Logout((saml2) -> saml2.addObjectPostProcessor(new ObjectPostProcessor<LogoutFilter>() {
@Override
public <O extends LogoutFilter> O postProcess(O filter) {
filter.setLogoutRequestMatcher(new AntPathRequestMatcher("/logout", "GET"));
return filter;
}
}));
return http.build();
// @formatter:on
}
@Bean
LogoutHandler logoutHandler() {
return this.mockLogoutHandler;
}
}
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class Saml2DefaultsWithObjectPostProcessorConfig {
static ObjectPostProcessor<Object> objectPostProcessor;
@Bean
SecurityFilterChain web(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authorize) -> authorize.anyRequest().authenticated())
.saml2Login(withDefaults())
.saml2Logout(withDefaults());
return http.build();
// @formatter:on
}
@Bean
static ObjectPostProcessor<Object> objectPostProcessor() {
return objectPostProcessor;
}
}
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class Saml2LogoutComponentsConfig {
@ -490,4 +593,13 @@ public class Saml2LogoutConfigurerTests {
}
static class ReflectingObjectPostProcessor implements ObjectPostProcessor<Object> {
@Override
public <O> O postProcess(O object) {
return object;
}
}
}