Polish AuthnRequest Customization Support

Having the application generate the AuthnRequest fresh allows Spring
Security to back away more gracefully. Using a Consumer implies that
the application will need to undo any values that Spring Security set
that the application doesn't want.

Also, if this does become a configuration burden, it can be simplified
in a separate ticket by exposing the default Converter.

Issue gh-8776
This commit is contained in:
Josh Cummings 2020-08-19 12:16:27 -06:00
parent 3694485056
commit af5c55c380
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 50 additions and 36 deletions

View File

@ -35,6 +35,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ConfigurableApplicationContext;
@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests {
}
@Test
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andReturn();
@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests {
@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests {
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
new OpenSamlAuthenticationRequestFactory();
authenticationRequestFactory.setAuthnRequestConsumerResolver(
context -> authnRequest -> authnRequest.setForceAuthn(true));
authenticationRequestFactory.setAuthenticationRequestContextConverter(
context -> {
AuthnRequest authnRequest = authnRequest();
authnRequest.setForceAuthn(true);
return authnRequest;
});
return authenticationRequestFactory;
}
}

View File

@ -25,8 +25,6 @@ import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.joda.time.DateTime;
@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
};
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
= context -> authnRequest -> {};
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
= this::createAuthnRequest;
/**
* Creates an {@link OpenSamlAuthenticationRequestFactory}
@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/
@Override
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context);
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
serialize(authnRequest);
@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/
@Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context);
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
String xml = serialize(authnRequest);
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
return createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl(),
this.protocolBindingResolver.convert(context));
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
return authnRequest;
}
private AuthnRequest createAuthnRequest
@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
/**
* Set the {@link AuthnRequest} post-processor resolver
*
* @param authnRequestConsumerResolver
* @param authenticationRequestContextConverter
* @since 5.4
*/
public void setAuthnRequestConsumerResolver(
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
public void setAuthenticationRequestContextConverter(
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
this.authenticationRequestContextConverter = authenticationRequestContextConverter;
}
/**

View File

@ -17,8 +17,6 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.Assert;
import org.junit.Before;
@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
private RelyingPartyRegistration relyingPartyRegistration;
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
private AuthnRequestUnmarshaller unmarshaller;
@Rule
public ExpectedException exception = ExpectedException.none();
@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.assertionConsumerServiceUrl("https://issuer/sso");
context = contextBuilder.build();
factory = new OpenSamlAuthenticationRequestFactory();
this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
}
@Test
@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
@Test
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
mock(Converter.class);
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
this.factory.createPostAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
verify(authenticationRequestContextConverter).convert(this.context);
}
@Test
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
mock(Function.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
mock(Converter.class);
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
this.factory.createRedirectAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context);
verify(authenticationRequestContextConverter).convert(this.context);
}
@Test
public void setAuthnRequestConsumerResolverWhenNullThenException() {
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
public void setAuthenticationRequestContextConverterWhenNullThenException() {
assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}

View File

@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID;
@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
final class TestOpenSamlObjects {
public final class TestOpenSamlObjects {
static {
OpenSamlInitializationService.initialize();
}
@ -188,6 +189,16 @@ final class TestOpenSamlObjects {
return conditions;
}
public static AuthnRequest authnRequest() {
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
authnRequest.setIssuer(issuer);
authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
authnRequest.setAssertionConsumerServiceURL(DESTINATION);
return authnRequest;
}
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
BasicCredential cred = getBasicCredential(credential);
cred.setEntityId(entityId);