mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-07-17 07:43:30 +00:00
Polish AuthnRequest Customization Support
Having the application generate the AuthnRequest fresh allows Spring Security to back away more gracefully. Using a Consumer implies that the application will need to undo any values that Spring Security set that the application doesn't want. Also, if this does become a configuration burden, it can be simplified in a separate ticket by exposing the default Converter. Issue gh-8776
This commit is contained in:
parent
3694485056
commit
af5c55c380
@ -35,6 +35,7 @@ import org.junit.Before;
|
|||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.opensaml.saml.saml2.core.Assertion;
|
import org.opensaml.saml.saml2.core.Assertion;
|
||||||
|
import org.opensaml.saml.saml2.core.AuthnRequest;
|
||||||
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.context.ConfigurableApplicationContext;
|
import org.springframework.context.ConfigurableApplicationContext;
|
||||||
@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify;
|
|||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
import static org.springframework.security.config.Customizer.withDefaults;
|
import static org.springframework.security.config.Customizer.withDefaults;
|
||||||
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
|
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
|
||||||
|
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
|
||||||
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
|
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
|
||||||
@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
|
public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
|
||||||
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
|
this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
|
||||||
|
|
||||||
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
|
||||||
.andReturn();
|
.andReturn();
|
||||||
@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests {
|
|||||||
|
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
@Import(Saml2LoginConfigBeans.class)
|
@Import(Saml2LoginConfigBeans.class)
|
||||||
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
|
static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void configure(HttpSecurity http) throws Exception {
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests {
|
|||||||
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
|
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
|
||||||
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
|
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
|
||||||
new OpenSamlAuthenticationRequestFactory();
|
new OpenSamlAuthenticationRequestFactory();
|
||||||
authenticationRequestFactory.setAuthnRequestConsumerResolver(
|
authenticationRequestFactory.setAuthenticationRequestContextConverter(
|
||||||
context -> authnRequest -> authnRequest.setForceAuthn(true));
|
context -> {
|
||||||
|
AuthnRequest authnRequest = authnRequest();
|
||||||
|
authnRequest.setForceAuthn(true);
|
||||||
|
return authnRequest;
|
||||||
|
});
|
||||||
return authenticationRequestFactory;
|
return authenticationRequestFactory;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,8 +25,6 @@ import java.util.Collection;
|
|||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.function.Consumer;
|
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
|
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
|
||||||
import org.joda.time.DateTime;
|
import org.joda.time.DateTime;
|
||||||
@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||||||
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
|
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
|
||||||
};
|
};
|
||||||
|
|
||||||
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
|
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
|
||||||
= context -> authnRequest -> {};
|
= this::createAuthnRequest;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an {@link OpenSamlAuthenticationRequestFactory}
|
* Creates an {@link OpenSamlAuthenticationRequestFactory}
|
||||||
@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
||||||
AuthnRequest authnRequest = createAuthnRequest(context);
|
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
|
||||||
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
|
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
|
||||||
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
|
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
|
||||||
serialize(authnRequest);
|
serialize(authnRequest);
|
||||||
@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
|
||||||
AuthnRequest authnRequest = createAuthnRequest(context);
|
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
|
||||||
String xml = serialize(authnRequest);
|
String xml = serialize(authnRequest);
|
||||||
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
|
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
|
||||||
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
|
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
|
||||||
@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||||||
}
|
}
|
||||||
|
|
||||||
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
|
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
|
||||||
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
|
return createAuthnRequest(context.getIssuer(),
|
||||||
context.getDestination(), context.getAssertionConsumerServiceUrl(),
|
context.getDestination(), context.getAssertionConsumerServiceUrl(),
|
||||||
this.protocolBindingResolver.convert(context));
|
this.protocolBindingResolver.convert(context));
|
||||||
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
|
|
||||||
return authnRequest;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private AuthnRequest createAuthnRequest
|
private AuthnRequest createAuthnRequest
|
||||||
@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
|
|||||||
/**
|
/**
|
||||||
* Set the {@link AuthnRequest} post-processor resolver
|
* Set the {@link AuthnRequest} post-processor resolver
|
||||||
*
|
*
|
||||||
* @param authnRequestConsumerResolver
|
* @param authenticationRequestContextConverter
|
||||||
* @since 5.4
|
* @since 5.4
|
||||||
*/
|
*/
|
||||||
public void setAuthnRequestConsumerResolver(
|
public void setAuthenticationRequestContextConverter(
|
||||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
|
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
|
||||||
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
|
Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
|
||||||
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
|
this.authenticationRequestContextConverter = authenticationRequestContextConverter;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -17,8 +17,6 @@
|
|||||||
package org.springframework.security.saml2.provider.service.authentication;
|
package org.springframework.security.saml2.provider.service.authentication;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.util.function.Consumer;
|
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
|
|||||||
import org.w3c.dom.Document;
|
import org.w3c.dom.Document;
|
||||||
import org.w3c.dom.Element;
|
import org.w3c.dom.Element;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.security.saml2.Saml2Exception;
|
import org.springframework.security.saml2.Saml2Exception;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
||||||
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
||||||
@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
|
|||||||
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
||||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
||||||
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
|
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
|
||||||
|
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
||||||
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
||||||
@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||||||
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
|
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
|
||||||
private RelyingPartyRegistration relyingPartyRegistration;
|
private RelyingPartyRegistration relyingPartyRegistration;
|
||||||
|
|
||||||
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
private AuthnRequestUnmarshaller unmarshaller;
|
||||||
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public ExpectedException exception = ExpectedException.none();
|
public ExpectedException exception = ExpectedException.none();
|
||||||
@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||||||
.assertionConsumerServiceUrl("https://issuer/sso");
|
.assertionConsumerServiceUrl("https://issuer/sso");
|
||||||
context = contextBuilder.build();
|
context = contextBuilder.build();
|
||||||
factory = new OpenSamlAuthenticationRequestFactory();
|
factory = new OpenSamlAuthenticationRequestFactory();
|
||||||
|
this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
||||||
|
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
||||||
mock(Function.class);
|
mock(Converter.class);
|
||||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
||||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
||||||
|
|
||||||
this.factory.createPostAuthenticationRequest(this.context);
|
this.factory.createPostAuthenticationRequest(this.context);
|
||||||
verify(authnRequestConsumerResolver).apply(this.context);
|
verify(authenticationRequestContextConverter).convert(this.context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
||||||
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
||||||
mock(Function.class);
|
mock(Converter.class);
|
||||||
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
||||||
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
||||||
|
|
||||||
this.factory.createRedirectAuthenticationRequest(this.context);
|
this.factory.createRedirectAuthenticationRequest(this.context);
|
||||||
verify(authnRequestConsumerResolver).apply(this.context);
|
verify(authenticationRequestContextConverter).convert(this.context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void setAuthnRequestConsumerResolverWhenNullThenException() {
|
public void setAuthenticationRequestContextConverterWhenNullThenException() {
|
||||||
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
|
assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
|
||||||
.isInstanceOf(IllegalArgumentException.class);
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion;
|
|||||||
import org.opensaml.saml.saml2.core.Attribute;
|
import org.opensaml.saml.saml2.core.Attribute;
|
||||||
import org.opensaml.saml.saml2.core.AttributeStatement;
|
import org.opensaml.saml.saml2.core.AttributeStatement;
|
||||||
import org.opensaml.saml.saml2.core.AttributeValue;
|
import org.opensaml.saml.saml2.core.AttributeValue;
|
||||||
|
import org.opensaml.saml.saml2.core.AuthnRequest;
|
||||||
import org.opensaml.saml.saml2.core.Conditions;
|
import org.opensaml.saml.saml2.core.Conditions;
|
||||||
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
||||||
import org.opensaml.saml.saml2.core.EncryptedID;
|
import org.opensaml.saml.saml2.core.EncryptedID;
|
||||||
@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
|
|||||||
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
|
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
|
||||||
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
|
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
|
||||||
|
|
||||||
final class TestOpenSamlObjects {
|
public final class TestOpenSamlObjects {
|
||||||
static {
|
static {
|
||||||
OpenSamlInitializationService.initialize();
|
OpenSamlInitializationService.initialize();
|
||||||
}
|
}
|
||||||
@ -188,6 +189,16 @@ final class TestOpenSamlObjects {
|
|||||||
return conditions;
|
return conditions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static AuthnRequest authnRequest() {
|
||||||
|
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
|
||||||
|
issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
|
||||||
|
AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
||||||
|
authnRequest.setIssuer(issuer);
|
||||||
|
authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
|
||||||
|
authnRequest.setAssertionConsumerServiceURL(DESTINATION);
|
||||||
|
return authnRequest;
|
||||||
|
}
|
||||||
|
|
||||||
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
|
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
|
||||||
BasicCredential cred = getBasicCredential(credential);
|
BasicCredential cred = getBasicCredential(credential);
|
||||||
cred.setEntityId(entityId);
|
cred.setEntityId(entityId);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user