Polish AuthnRequest Customization Support
Having the application generate the AuthnRequest fresh allows Spring Security to back away more gracefully. Using a Consumer implies that the application will need to undo any values that Spring Security set that the application doesn't want. Also, if this does become a configuration burden, it can be simplified in a separate ticket by exposing the default Converter. Issue gh-8776
This commit is contained in:
parent
3694485056
commit
af5c55c380
|
@ -35,6 +35,7 @@ import org.junit.Before;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.opensaml.saml.saml2.core.Assertion;
|
||||
import org.opensaml.saml.saml2.core.AuthnRequest;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ConfigurableApplicationContext;
|
||||
|
@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify;
|
|||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.config.Customizer.withDefaults;
|
||||
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
|
||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
||||
|
@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
|
||||
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
|
||||
public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
|
||||
this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
|
||||
|
||||
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
||||
.andReturn();
|
||||
|
@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests {
|
|||
|
||||
@EnableWebSecurity
|
||||
@Import(Saml2LoginConfigBeans.class)
|
||||
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
|
||||
static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
|
||||
|
||||
@Override
|
||||
protected void configure(HttpSecurity http) throws Exception {
|
||||
|
@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests {
|
|||
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
|
||||
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
|
||||
new OpenSamlAuthenticationRequestFactory();
|
||||
authenticationRequestFactory.setAuthnRequestConsumerResolver(
|
||||
context -> authnRequest -> authnRequest.setForceAuthn(true));
|
||||
authenticationRequestFactory.setAuthenticationRequestContextConverter(
|
||||
context -> {
|
||||
AuthnRequest authnRequest = authnRequest();
|
||||
authnRequest.setForceAuthn(true);
|
||||
return authnRequest;
|
||||
});
|
||||
return authenticationRequestFactory;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,8 +25,6 @@ import java.util.Collection;
|
|||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
|
||||
import org.joda.time.DateTime;
|
||||
|
@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
|
||||
};
|
||||
|
||||
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
|
||||
= context -> authnRequest -> {};
|
||||
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
|
||||
= this::createAuthnRequest;
|
||||
|
||||
/**
|
||||
* Creates an {@link OpenSamlAuthenticationRequestFactory}
|
||||
|
@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||
*/
|
||||
@Override
|
||||
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
||||
AuthnRequest authnRequest = createAuthnRequest(context);
|
||||
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
|
||||
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
|
||||
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
|
||||
serialize(authnRequest);
|
||||
|
@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||
*/
|
||||
@Override
|
||||
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
||||
AuthnRequest authnRequest = createAuthnRequest(context);
|
||||
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
|
||||
String xml = serialize(authnRequest);
|
||||
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
|
||||
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
|
||||
|
@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
|
||||
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
|
||||
return createAuthnRequest(context.getIssuer(),
|
||||
context.getDestination(), context.getAssertionConsumerServiceUrl(),
|
||||
this.protocolBindingResolver.convert(context));
|
||||
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
|
||||
return authnRequest;
|
||||
}
|
||||
|
||||
private AuthnRequest createAuthnRequest
|
||||
|
@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||
/**
|
||||
* Set the {@link AuthnRequest} post-processor resolver
|
||||
*
|
||||
* @param authnRequestConsumerResolver
|
||||
* @param authenticationRequestContextConverter
|
||||
* @since 5.4
|
||||
*/
|
||||
public void setAuthnRequestConsumerResolver(
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
|
||||
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
|
||||
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
|
||||
public void setAuthenticationRequestContextConverter(
|
||||
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
|
||||
Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
|
||||
this.authenticationRequestContextConverter = authenticationRequestContextConverter;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
package org.springframework.security.saml2.provider.service.authentication;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
|
@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
|
|||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.security.saml2.Saml2Exception;
|
||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
||||
|
@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
|
|||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
|
||||
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
|
||||
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
||||
|
@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
|
||||
private RelyingPartyRegistration relyingPartyRegistration;
|
||||
|
||||
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
||||
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
||||
private AuthnRequestUnmarshaller unmarshaller;
|
||||
|
||||
@Rule
|
||||
public ExpectedException exception = ExpectedException.none();
|
||||
|
@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||
.assertionConsumerServiceUrl("https://issuer/sso");
|
||||
context = contextBuilder.build();
|
||||
factory = new OpenSamlAuthenticationRequestFactory();
|
||||
this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
||||
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||
|
||||
@Test
|
||||
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
||||
mock(Converter.class);
|
||||
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
||||
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
||||
|
||||
this.factory.createPostAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
verify(authenticationRequestContextConverter).convert(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
||||
mock(Function.class);
|
||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
||||
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
||||
mock(Converter.class);
|
||||
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
||||
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
||||
|
||||
this.factory.createRedirectAuthenticationRequest(this.context);
|
||||
verify(authnRequestConsumerResolver).apply(this.context);
|
||||
verify(authenticationRequestContextConverter).convert(this.context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthnRequestConsumerResolverWhenNullThenException() {
|
||||
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
|
||||
public void setAuthenticationRequestContextConverterWhenNullThenException() {
|
||||
assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion;
|
|||
import org.opensaml.saml.saml2.core.Attribute;
|
||||
import org.opensaml.saml.saml2.core.AttributeStatement;
|
||||
import org.opensaml.saml.saml2.core.AttributeValue;
|
||||
import org.opensaml.saml.saml2.core.AuthnRequest;
|
||||
import org.opensaml.saml.saml2.core.Conditions;
|
||||
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
||||
import org.opensaml.saml.saml2.core.EncryptedID;
|
||||
|
@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
|
|||
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
|
||||
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
|
||||
|
||||
final class TestOpenSamlObjects {
|
||||
public final class TestOpenSamlObjects {
|
||||
static {
|
||||
OpenSamlInitializationService.initialize();
|
||||
}
|
||||
|
@ -188,6 +189,16 @@ final class TestOpenSamlObjects {
|
|||
return conditions;
|
||||
}
|
||||
|
||||
public static AuthnRequest authnRequest() {
|
||||
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
|
||||
issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
|
||||
AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
||||
authnRequest.setIssuer(issuer);
|
||||
authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
|
||||
authnRequest.setAssertionConsumerServiceURL(DESTINATION);
|
||||
return authnRequest;
|
||||
}
|
||||
|
||||
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
|
||||
BasicCredential cred = getBasicCredential(credential);
|
||||
cred.setEntityId(entityId);
|
||||
|
|
Loading…
Reference in New Issue