Add Saml2AuthenticationRequestRepository

Closes gh-9185
This commit is contained in:
Marcus Da Coregio 2021-06-28 13:58:42 -03:00 committed by Josh Cummings
parent 6b68a6d62b
commit 16e17d242e
15 changed files with 657 additions and 19 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractAu
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@ -40,6 +41,8 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSa
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
@ -206,6 +209,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
} }
this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http), this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
this.loginProcessingUrl); this.loginProcessingUrl);
setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter);
setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter); setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl); super.loginProcessingUrl(this.loginProcessingUrl);
if (StringUtils.hasText(this.loginPage)) { if (StringUtils.hasText(this.loginPage)) {
@ -252,6 +256,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
} }
} }
private void setAuthenticationRequestRepository(B http,
Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) {
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}
private AuthenticationConverter getAuthenticationConverter(B http) { private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter == null) { if (this.authenticationConverter == null) {
return new Saml2AuthenticationTokenConverter( return new Saml2AuthenticationTokenConverter(
@ -311,6 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
return idps; return idps;
} }
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> getAuthenticationRequestRepository(
B http) {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getBeanOrNull(http,
Saml2AuthenticationRequestRepository.class);
if (repository == null) {
return new HttpSessionSaml2AuthenticationRequestRepository();
}
return repository;
}
private <C> C getSharedOrBean(B http, Class<C> clazz) { private <C> C getSharedOrBean(B http, Class<C> clazz) {
C shared = http.getSharedObject(clazz); C shared = http.getSharedObject(clazz);
if (shared != null) { if (shared != null) {
@ -348,8 +367,12 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
private Filter build(B http) { private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
return postProcess( Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver)); http);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
authenticationRequestResolver);
filter.setAuthenticationRequestRepository(repository);
return postProcess(filter);
} }
private Saml2AuthenticationRequestFactory getResolver(B http) { private Saml2AuthenticationRequestFactory getResolver(B http) {

View File

@ -63,6 +63,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@ -76,9 +77,11 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpRequestResponseHolder;
@ -237,6 +240,29 @@ public class Saml2LoginConfigurerTests {
assertThat(exception.getCause()).isInstanceOf(IOException.class); assertThat(exception.getCause()).isInstanceOf(IOException.class);
} }
@Test
public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id");
this.mvc.perform(request).andExpect(status().isFound());
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class),
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse",
SIGNED_RESPONSE);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
this.mvc.perform(request);
verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class));
verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration() { private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider // get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@ -371,7 +397,7 @@ public class Saml2LoginConfigurerTests {
@Bean @Bean
Saml2AuthenticationRequestContextResolver resolver() { Saml2AuthenticationRequestContextResolver resolver() {
return resolver; return this.resolver;
} }
} }
@ -420,6 +446,27 @@ public class Saml2LoginConfigurerTests {
} }
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestRepository {
private static final Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = mock(
Saml2AuthenticationRequestRepository.class);
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.authorizeRequests((authz) -> authz.anyRequest().authenticated());
http.saml2Login(withDefaults());
return http.build();
}
@Bean
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
return this.repository;
}
}
static class Saml2LoginConfigBeans { static class Saml2LoginConfigBeans {
@Bean @Bean

View File

@ -1610,3 +1610,33 @@ http {
The success handler will send logout requests to the asserting party. The success handler will send logout requests to the asserting party.
The request matcher will detect logout requests from the asserting party. The request matcher will detect logout requests from the asserting party.
[[servlet-saml2login-store-authn-request]]
=== Storing the `AuthnRequest`
The `Saml2AuthenticationRequestRepository` is responsible for the persistence of the `AuthnRequest` from the time the `AuthnRequest` <<servlet-saml2login-sp-initiated-factory,is initiated>> to the time the `SAMLResponse` <<servlet-saml2login-authenticate-responses,is received>>.
The `Saml2AuthenticationTokenConverter` is responsible for loading the `AuthnRequest` from the `Saml2AuthenticationRequestRepository` and saving it into the `Saml2AuthenticationToken`.
The default implementation of `Saml2AuthenticationRequestRepository` is `HttpSessionSaml2AuthenticationRequestRepository`, which stores the `AuthnRequest` in the `HttpSession`.
If you have a custom implementation of `Saml2AuthenticationRequestRepository`, you may configure it by exposing it as a `@Bean` as shown in the following example:
====
.Java
[source,java,role="primary"]
----
@Bean
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
return new CustomSaml2AuthenticationRequestRepository();
}
----
.Kotlin
[source,kotlin,role="secondary"]
----
@Bean
open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
return CustomSaml2AuthenticationRequestRepository()
}
----
====

View File

@ -55,6 +55,7 @@ dependencies {
testImplementation "org.junit.jupiter:junit-jupiter-params" testImplementation "org.junit.jupiter:junit-jupiter-params"
testImplementation "org.junit.jupiter:junit-jupiter-engine" testImplementation "org.junit.jupiter:junit-jupiter-engine"
testImplementation "org.mockito:mockito-core" testImplementation "org.mockito:mockito-core"
testImplementation "org.mockito:mockito-inline"
testImplementation "org.mockito:mockito-junit-jupiter" testImplementation "org.mockito:mockito-junit-jupiter"
testImplementation "org.springframework:spring-test" testImplementation "org.springframework:spring-test"
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,6 +38,32 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
private final String saml2Response; private final String saml2Response;
private final AbstractSaml2AuthenticationRequest authenticationRequest;
/**
* Creates a {@link Saml2AuthenticationToken} with the provided parameters.
*
* Note that the given {@link RelyingPartyRegistration} should have all its templates
* resolved at this point. See
* {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
* for an example of performing that resolution.
* @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
* use
* @param saml2Response the SAML 2.0 response to authenticate
* @param authenticationRequest the {@code AuthNRequest} sent to the asserting party
*
* @since 5.6
*/
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response,
AbstractSaml2AuthenticationRequest authenticationRequest) {
super(Collections.emptyList());
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null");
this.relyingPartyRegistration = relyingPartyRegistration;
this.saml2Response = saml2Response;
this.authenticationRequest = authenticationRequest;
}
/** /**
* Creates a {@link Saml2AuthenticationToken} with the provided parameters * Creates a {@link Saml2AuthenticationToken} with the provided parameters
* *
@ -52,11 +78,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @since 5.4 * @since 5.4
*/ */
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) { public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
super(Collections.emptyList()); this(relyingPartyRegistration, saml2Response, null);
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null");
this.relyingPartyRegistration = relyingPartyRegistration;
this.saml2Response = saml2Response;
} }
/** /**
@ -81,6 +103,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId)) .entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
.build(); .build();
this.saml2Response = saml2Response; this.saml2Response = saml2Response;
this.authenticationRequest = null;
} }
/** /**
@ -179,4 +202,14 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
} }
/**
* Returns the authentication request sent to the assertion party or {@code null} if
* no authentication request is present
* @return the authentication request sent to the assertion party
* @since 5.6
*/
public AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
return this.authenticationRequest;
}
} }

View File

@ -0,0 +1,73 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.servlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
/**
* A {@link Saml2AuthenticationRequestRepository} implementation that uses
* {@link HttpSession} to store and retrieve the
* {@link AbstractSaml2AuthenticationRequest}
*
* @author Marcus Da Coregio
* @since 5.6
*/
public class HttpSessionSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
private static final String DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME = HttpSessionSaml2AuthenticationRequestRepository.class
.getName().concat(".SAML2_AUTHN_REQUEST");
private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME;
@Override
public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
HttpSession httpSession = request.getSession(false);
if (httpSession == null) {
return null;
}
return (AbstractSaml2AuthenticationRequest) httpSession.getAttribute(this.saml2AuthnRequestAttributeName);
}
@Override
public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
HttpServletRequest request, HttpServletResponse response) {
if (authenticationRequest == null) {
removeAuthenticationRequest(request, response);
return;
}
HttpSession httpSession = request.getSession();
httpSession.setAttribute(this.saml2AuthnRequestAttributeName, authenticationRequest);
}
@Override
public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
HttpServletResponse response) {
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
if (authenticationRequest == null) {
return null;
}
HttpSession httpSession = request.getSession();
httpSession.removeAttribute(this.saml2AuthnRequestAttributeName);
return authenticationRequest;
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.servlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
/**
* A repository for {@link AbstractSaml2AuthenticationRequest}
*
* @param <T> the type of SAML 2.0 Authentication Request
* @author Marcus Da Coregio
* @since 5.6
*/
public interface Saml2AuthenticationRequestRepository<T extends AbstractSaml2AuthenticationRequest> {
/**
* Loads the {@link AbstractSaml2AuthenticationRequest} from the request
* @param request the current request
* @return the {@link AbstractSaml2AuthenticationRequest} or {@code null} if it is not
* present
*/
T loadAuthenticationRequest(HttpServletRequest request);
/**
* Saves the current authentication request using the {@link HttpServletRequest} and
* {@link HttpServletResponse}
* @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest}
* @param request the current request
* @param response the current response
*/
void saveAuthenticationRequest(T authenticationRequest, HttpServletRequest request, HttpServletResponse response);
/**
* Removes the authentication request using the {@link HttpServletRequest} and
* {@link HttpServletResponse}
* @param request the current request
* @param response the current response
* @return the removed {@link AbstractSaml2AuthenticationRequest} or {@code null} if
* it is not present
*/
T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,8 +23,11 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
@ -42,6 +45,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
private final AuthenticationConverter authenticationConverter; private final AuthenticationConverter authenticationConverter;
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
/** /**
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is
* configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL
@ -100,7 +105,33 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
"No relying party registration found"); "No relying party registration found");
throw new Saml2AuthenticationException(saml2Error); throw new Saml2AuthenticationException(saml2Error);
} }
this.authenticationRequestRepository.removeAuthenticationRequest(request, response);
return getAuthenticationManager().authenticate(authentication); return getAuthenticationManager().authenticate(authentication);
} }
/**
* Use the given {@link Saml2AuthenticationRequestRepository} to remove the saved
* authentication request. If the {@link #authenticationConverter} is of the type
* {@link Saml2AuthenticationTokenConverter}, the
* {@link Saml2AuthenticationRequestRepository} will also be set into the
* {@link #authenticationConverter}.
* @param authenticationRequestRepository the
* {@link Saml2AuthenticationRequestRepository} to use
* @since 5.6
*/
public void setAuthenticationRequestRepository(
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
this.authenticationRequestRepository = authenticationRequestRepository;
setAuthenticationRequestRepositoryIntoAuthenticationConverter(authenticationRequestRepository);
}
private void setAuthenticationRequestRepositoryIntoAuthenticationConverter(
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
if (this.authenticationConverter instanceof Saml2AuthenticationTokenConverter) {
Saml2AuthenticationTokenConverter authenticationTokenConverter = (Saml2AuthenticationTokenConverter) this.authenticationConverter;
authenticationTokenConverter.setAuthenticationRequestRepository(authenticationRequestRepository);
}
}
} }

View File

@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
import org.opensaml.core.Version; import org.opensaml.core.Version;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@ -34,6 +35,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
@ -79,6 +82,8 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
/** /**
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided
* parameters * parameters
@ -149,6 +154,19 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
this.redirectMatcher = redirectMatcher; this.redirectMatcher = redirectMatcher;
} }
/**
* Use the given {@link Saml2AuthenticationRequestRepository} to save the
* authentication request
* @param authenticationRequestRepository the
* {@link Saml2AuthenticationRequestRepository} to use
* @since 5.6
*/
public void setAuthenticationRequestRepository(
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
this.authenticationRequestRepository = authenticationRequestRepository;
}
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
@ -165,17 +183,18 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
} }
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
sendRedirect(response, context); sendRedirect(request, response, context);
} }
else { else {
sendPost(response, context); sendPost(request, response, context);
} }
} }
private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context) private void sendRedirect(HttpServletRequest request, HttpServletResponse response,
throws IOException { Saml2AuthenticationRequestContext context) throws IOException {
Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
.createRedirectAuthenticationRequest(context); .createRedirectAuthenticationRequest(context);
this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
UriComponentsBuilder uriBuilder = UriComponentsBuilder UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(authenticationRequest.getAuthenticationRequestUri()); .fromUriString(authenticationRequest.getAuthenticationRequestUri());
addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder); addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder);
@ -194,9 +213,11 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
} }
} }
private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException { private void sendPost(HttpServletRequest request, HttpServletResponse response,
Saml2AuthenticationRequestContext context) throws IOException {
Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
.createPostAuthenticationRequest(context); .createPostAuthenticationRequest(context);
this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
String html = createSamlPostRequestFormData(authenticationRequest); String html = createSamlPostRequestFormData(authenticationRequest);
response.setContentType(MediaType.TEXT_HTML_VALUE); response.setContentType(MediaType.TEXT_HTML_VALUE);
response.getWriter().write(html); response.getWriter().write(html);

View File

@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.function.Function;
import java.util.zip.Inflater; import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream; import java.util.zip.InflaterOutputStream;
@ -30,9 +31,12 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -50,6 +54,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver; private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
/** /**
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
* resolving {@link RelyingPartyRegistration}s * resolving {@link RelyingPartyRegistration}s
@ -60,6 +66,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) { Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
} }
@Override @Override
@ -74,7 +81,25 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
} }
byte[] b = samlDecode(saml2Response); byte[] b = samlDecode(saml2Response);
saml2Response = inflateIfRequired(request, b); saml2Response = inflateIfRequired(request, b);
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response); AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
}
/**
* Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
* request.
* @param authenticationRequestRepository the
* {@link Saml2AuthenticationRequestRepository} to use
* @since 5.6
*/
public void setAuthenticationRequestRepository(
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
this.loader = authenticationRequestRepository::loadAuthenticationRequest;
}
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
return this.loader.apply(request);
} }
private String inflateIfRequired(HttpServletRequest request, byte[] b) { private String inflateIfRequired(HttpServletRequest request, byte[] b) {

View File

@ -0,0 +1,38 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.authentication;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
/**
* Tests instances for {@link Saml2AuthenticationToken}
*
* @author Marcus Da Coregio
*/
public final class TestSaml2AuthenticationTokens {
private TestSaml2AuthenticationTokens() {
}
public static Saml2AuthenticationToken token() {
RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.build();
return new Saml2AuthenticationToken(relyingPartyRegistration, "saml2-xml-response-object");
}
}

View File

@ -0,0 +1,143 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.servlet;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* @author Marcus Da Coregio
*/
public class HttpSessionSaml2AuthenticationRequestRepositoryTests {
private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
private MockHttpServletRequest request;
private MockHttpServletResponse response;
private HttpSessionSaml2AuthenticationRequestRepository authenticationRequestRepository;
@BeforeEach
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
}
@Test
public void loadAuthenticationRequestWhenInvalidSessionThenNull() {
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest).isNull();
}
@Test
public void loadAuthenticationRequestWhenNoAttributeInSessionThenNull() {
this.request.getSession();
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest).isNull();
}
@Test
public void loadAuthenticationRequestWhenAttributeInSessionThenReturnsAuthenticationRequest() {
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
this.request.getSession();
this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
this.response);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
}
@Test
public void saveAuthenticationRequestWhenSessionDontExistsThenCreateAndSave() {
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
this.response);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest).isNotNull();
}
@Test
public void saveAuthenticationRequestWhenSessionExistsThenSave() {
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
this.request.getSession();
this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
this.response);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest).isNotNull();
}
@Test
public void saveAuthenticationRequestWhenNullAuthenticationRequestThenDontSave() {
this.request.getSession();
this.authenticationRequestRepository.saveAuthenticationRequest(null, this.request, this.response);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest).isNull();
}
@Test
public void removeAuthenticationRequestWhenInvalidSessionThenReturnNull() {
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.removeAuthenticationRequest(this.request, this.response);
assertThat(authenticationRequest).isNull();
}
@Test
public void removeAuthenticationRequestWhenAttributeInSessionThenRemoveAuthenticationRequest() {
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
this.request.getSession();
this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
this.response);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.removeAuthenticationRequest(this.request, this.response);
AbstractSaml2AuthenticationRequest authenticationRequestAfterRemove = this.authenticationRequestRepository
.loadAuthenticationRequest(this.request);
assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
assertThat(authenticationRequestAfterRemove).isNull();
}
@Test
public void removeAuthenticationRequestWhenValidSessionNoAttributeThenReturnsNull() {
MockHttpSession session = mock(MockHttpSession.class);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setSession(session);
AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
.removeAuthenticationRequest(request, this.response);
verify(session).getAttribute(anyString());
assertThat(authenticationRequest).isNull();
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,12 +24,20 @@ import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
public class Saml2WebSsoAuthenticationFilterTests { public class Saml2WebSsoAuthenticationFilterTests {
@ -84,4 +92,45 @@ public class Saml2WebSsoAuthenticationFilterTests {
.withMessage("No relying party registration found"); .withMessage("No relying party registration found");
} }
@Test
public void attemptAuthenticationWhenSavedAuthnRequestThenRemovesAuthnRequest() {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
given(authenticationConverter.convert(this.request)).willReturn(TestSaml2AuthenticationTokens.token());
this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
this.filter.setAuthenticationManager((authentication) -> null);
this.request.setPathInfo("/some/other/path/idp-registration-id");
this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
this.filter.attemptAuthentication(this.request, this.response);
verify(authenticationRequestRepository).removeAuthenticationRequest(this.request, this.response);
}
@Test
public void setAuthenticationRequestRepositoryWhenNullThenThrowsIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationRequestRepository(null))
.withMessage("authenticationRequestRepository cannot be null");
}
@Test
public void setAuthenticationRequestRepositoryWhenExpectedAuthenticationConverterTypeThenSetLoaderIntoConverter() {
Saml2AuthenticationTokenConverter authenticationConverterMock = mock(Saml2AuthenticationTokenConverter.class);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverterMock,
"/some/other/path/{registrationId}");
this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
verify(authenticationConverterMock).setAuthenticationRequestRepository(authenticationRequestRepository);
}
@Test
public void setAuthenticationRequestRepositoryWhenNotExpectedAuthenticationConverterTypeThenDontSet() {
AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
verifyNoInteractions(authenticationConverter);
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.HtmlUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -43,6 +45,7 @@ import org.springframework.web.util.UriUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -60,6 +63,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class); private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class);
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
private MockHttpServletRequest request; private MockHttpServletRequest request;
private MockHttpServletResponse response; private MockHttpServletResponse response;
@ -79,6 +85,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
.providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL)) .providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL))
.assertionConsumerServiceUrlTemplate("template") .assertionConsumerServiceUrlTemplate("template")
.credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential())); .credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()));
this.filter.setAuthenticationRequestRepository(this.authenticationRequestRepository);
} }
@Test @Test
@ -216,4 +223,37 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
assertThat(this.response.getStatus()).isEqualTo(401); assertThat(this.response.getStatus()).isEqualTo(401);
} }
@Test
public void setAuthenticationRequestRepositoryWhenNullThenException() {
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver,
this.factory);
assertThatIllegalArgumentException().isThrownBy(() -> filter.setAuthenticationRequestRepository(null));
}
@Test
public void doFilterWhenRedirectThenSaveRedirectRequest() throws ServletException, IOException {
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
Saml2RedirectAuthenticationRequest request = redirectAuthenticationRequest(context).build();
given(this.resolver.resolve(any())).willReturn(context);
given(this.factory.createRedirectAuthenticationRequest(any())).willReturn(request);
this.filter.doFilterInternal(this.request, this.response, this.filterChain);
verify(this.authenticationRequestRepository).saveAuthenticationRequest(
any(Saml2RedirectAuthenticationRequest.class), eq(this.request), eq(this.response));
}
@Test
public void doFilterWhenPostThenSaveRedirectRequest() throws ServletException, IOException {
RelyingPartyRegistration registration = this.rpBuilder
.assertingPartyDetails((asserting) -> asserting.singleSignOnServiceBinding(Saml2MessageBinding.POST))
.build();
Saml2AuthenticationRequestContext context = authenticationRequestContext()
.relyingPartyRegistration(registration).build();
Saml2PostAuthenticationRequest request = postAuthenticationRequest(context).build();
given(this.resolver.resolve(any())).willReturn(context);
given(this.factory.createPostAuthenticationRequest(any())).willReturn(request);
this.filter.doFilterInternal(this.request, this.response, this.filterChain);
verify(this.authenticationRequestRepository).saveAuthenticationRequest(
any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response));
}
} }

View File

@ -31,10 +31,12 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.util.StreamUtils; import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
@ -43,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
public class Saml2AuthenticationTokenConverterTests { public class Saml2AuthenticationTokenConverterTests {
@ -155,6 +158,27 @@ public class Saml2AuthenticationTokenConverterTests {
validateSsoCircleXml(token.getSaml2Response()); validateSsoCircleXml(token.getSaml2Response());
} }
@Test
public void convertWhenSavedAuthenticationRequestThenToken() {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
Saml2AuthenticationRequestRepository.class);
AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
converter.setAuthenticationRequestRepository(authenticationRequestRepository);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
.willReturn(authenticationRequest);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
Saml2AuthenticationToken token = converter.convert(request);
assertThat(token.getSaml2Response()).isEqualTo("response");
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
}
private void validateSsoCircleXml(String xml) { private void validateSsoCircleXml(String xml) {
assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"")
.contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")