Polish AuthnRequest Customization Support

Having the application generate the AuthnRequest fresh allows Spring
Security to back away more gracefully. Using a Consumer implies that
the application will need to undo any values that Spring Security set
that the application doesn't want.

Also, if this does become a configuration burden, it can be simplified
in a separate ticket by exposing the default Converter.

Issue gh-8776
This commit is contained in:
Josh Cummings 2020-08-19 12:16:27 -06:00
parent 3694485056
commit af5c55c380
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 50 additions and 36 deletions

View File

@ -35,6 +35,7 @@ import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.ConfigurableApplicationContext;
@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests {
} }
@Test @Test
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception { public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire(); this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andReturn(); .andReturn();
@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests {
@EnableWebSecurity @EnableWebSecurity
@Import(Saml2LoginConfigBeans.class) @Import(Saml2LoginConfigBeans.class)
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter { static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests {
Saml2AuthenticationRequestFactory authenticationRequestFactory() { Saml2AuthenticationRequestFactory authenticationRequestFactory() {
OpenSamlAuthenticationRequestFactory authenticationRequestFactory = OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
new OpenSamlAuthenticationRequestFactory(); new OpenSamlAuthenticationRequestFactory();
authenticationRequestFactory.setAuthnRequestConsumerResolver( authenticationRequestFactory.setAuthenticationRequestContextConverter(
context -> authnRequest -> authnRequest.setForceAuthn(true)); context -> {
AuthnRequest authnRequest = authnRequest();
authnRequest.setForceAuthn(true);
return authnRequest;
});
return authenticationRequestFactory; return authenticationRequestFactory;
} }
} }

View File

@ -25,8 +25,6 @@ import java.util.Collection;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import net.shibboleth.utilities.java.support.xml.SerializeSupport; import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn(); return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
}; };
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
= context -> authnRequest -> {}; = this::createAuthnRequest;
/** /**
* Creates an {@link OpenSamlAuthenticationRequestFactory} * Creates an {@link OpenSamlAuthenticationRequestFactory}
@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/ */
@Override @Override
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context); AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ? String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
serialize(authnRequest); serialize(authnRequest);
@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/ */
@Override @Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) { public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context); AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
String xml = serialize(authnRequest); String xml = serialize(authnRequest);
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = samlEncode(samlDeflate(xml)); String deflatedAndEncoded = samlEncode(samlDeflate(xml));
@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
} }
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) { private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(), return createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl(), context.getDestination(), context.getAssertionConsumerServiceUrl(),
this.protocolBindingResolver.convert(context)); this.protocolBindingResolver.convert(context));
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
return authnRequest;
} }
private AuthnRequest createAuthnRequest private AuthnRequest createAuthnRequest
@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
/** /**
* Set the {@link AuthnRequest} post-processor resolver * Set the {@link AuthnRequest} post-processor resolver
* *
* @param authnRequestConsumerResolver * @param authenticationRequestContextConverter
* @since 5.4 * @since 5.4
*/ */
public void setAuthnRequestConsumerResolver( public void setAuthenticationRequestContextConverter(
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) { Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null"); Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
this.authnRequestConsumerResolver = authnRequestConsumerResolver; this.authenticationRequestContextConverter = authenticationRequestContextConverter;
} }
/** /**

View File

@ -17,8 +17,6 @@
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
private RelyingPartyRegistration relyingPartyRegistration; private RelyingPartyRegistration relyingPartyRegistration;
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory() private AuthnRequestUnmarshaller unmarshaller;
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
@Rule @Rule
public ExpectedException exception = ExpectedException.none(); public ExpectedException exception = ExpectedException.none();
@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.assertionConsumerServiceUrl("https://issuer/sso"); .assertionConsumerServiceUrl("https://issuer/sso");
context = contextBuilder.build(); context = contextBuilder.build();
factory = new OpenSamlAuthenticationRequestFactory(); factory = new OpenSamlAuthenticationRequestFactory();
this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
} }
@Test @Test
@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
@Test @Test
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() { public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver = Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
mock(Function.class); mock(Converter.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
this.factory.createPostAuthenticationRequest(this.context); this.factory.createPostAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context); verify(authenticationRequestContextConverter).convert(this.context);
} }
@Test @Test
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() { public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver = Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
mock(Function.class); mock(Converter.class);
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
this.factory.createRedirectAuthenticationRequest(this.context); this.factory.createRedirectAuthenticationRequest(this.context);
verify(authnRequestConsumerResolver).apply(this.context); verify(authenticationRequestContextConverter).convert(this.context);
} }
@Test @Test
public void setAuthnRequestConsumerResolverWhenNullThenException() { public void setAuthenticationRequestContextConverterWhenNullThenException() {
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null)) assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }

View File

@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue; import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.EncryptedID;
@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
final class TestOpenSamlObjects { public final class TestOpenSamlObjects {
static { static {
OpenSamlInitializationService.initialize(); OpenSamlInitializationService.initialize();
} }
@ -188,6 +189,16 @@ final class TestOpenSamlObjects {
return conditions; return conditions;
} }
public static AuthnRequest authnRequest() {
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
authnRequest.setIssuer(issuer);
authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
authnRequest.setAssertionConsumerServiceURL(DESTINATION);
return authnRequest;
}
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) { static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
BasicCredential cred = getBasicCredential(credential); BasicCredential cred = getBasicCredential(credential);
cred.setEntityId(entityId); cred.setEntityId(entityId);