Post-process AuthenticationRequestFilter

Fixes gh-8552
This commit is contained in:
Josh Cummings 2020-04-17 16:48:59 -06:00
parent 8e7c4c143c
commit 51a0cffd36
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 58 additions and 2 deletions

View File

@ -323,9 +323,9 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
private Filter build(B http) { private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
return new Saml2WebSsoAuthenticationRequestFilter( return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
authenticationRequestResolver); authenticationRequestResolver));
} }
private Saml2AuthenticationRequestFactory getResolver(B http) { private Saml2AuthenticationRequestFactory getResolver(B http) {

View File

@ -23,6 +23,7 @@ import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
@ -55,9 +56,13 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
@ -66,10 +71,15 @@ import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
/** /**
* Tests for different Java configuration for {@link Saml2LoginConfigurer} * Tests for different Java configuration for {@link Saml2LoginConfigurer}
@ -133,6 +143,20 @@ public class Saml2LoginConfigurerTests {
validateSaml2WebSsoAuthenticationFilterConfiguration(); validateSaml2WebSsoAuthenticationFilterConfiguration();
} }
@Test
public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire();
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver =
CustomAuthenticationRequestContextResolver.resolver;
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
.thenReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andExpect(status().isFound());
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration() { private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider // get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@ -219,6 +243,38 @@ public class Saml2LoginConfigurerTests {
} }
} }
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {
private static final Saml2AuthenticationRequestContextResolver resolver =
mock(Saml2AuthenticationRequestContextResolver.class);
@Override
protected void configure(HttpSecurity http) throws Exception {
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
@Override
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
filter.setAuthenticationRequestContextResolver(resolver);
return filter;
}
};
http
.authorizeRequests(authz -> authz
.anyRequest().authenticated()
)
.saml2Login(saml2 -> saml2
.addObjectPostProcessor(processor)
);
}
@Bean
Saml2AuthenticationRequestContextResolver resolver() {
return resolver;
}
}
private static AuthenticationManager getAuthenticationManagerMock(String role) { private static AuthenticationManager getAuthenticationManagerMock(String role) {
return new AuthenticationManager() { return new AuthenticationManager() {