From da7477cd41504569a5087a4d1b593bd341edaac1 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 18 Aug 2020 15:57:02 -0600 Subject: [PATCH] Add Response to Authentication Conversion Support Closes gh-8010 --- .../OpenSamlAuthenticationProvider.java | 89 +++++++++++++++++-- .../OpenSamlAuthenticationProviderTests.java | 40 ++++++++- .../authentication/TestOpenSamlObjects.java | 8 +- 3 files changed, 128 insertions(+), 9 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index d9e4c3736c..f23c644fbb 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -28,7 +28,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Function; import javax.annotation.Nonnull; import javax.xml.namespace.QName; @@ -185,8 +184,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private GrantedAuthoritiesMapper authoritiesMapper = (a -> a); private Duration responseTimeValidationSkew = Duration.ofMinutes(5); - private Function> authenticationConverter = - token -> response -> { + private Converter responseAuthenticationConverter = + responseToken -> { + Response response = responseToken.response; + Saml2AuthenticationToken token = responseToken.token; Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); String username = assertion.getSubject().getNameID().getValue(); Map> attributes = getAssertionAttributes(assertion); @@ -255,11 +256,42 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.assertionValidator = assertionValidator; } + /** + * Set the {@link Converter} to use for converting a validated {@link Response} into + * an {@link AbstractAuthenticationToken}. + * + * You can delegate to the default behavior by calling {@link #createDefaultResponseAuthenticationConverter()} + * like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 * 	Converter<ResponseToken, Saml2Authentication> authenticationConverter =
+	 * 			createDefaultResponseAuthenticationConverter();
+	 *	provider.setResponseAuthenticationConverter(responseToken -> {
+	 *		Saml2Authentication authentication = authenticationConverter.convert(responseToken);
+	 *		User user = myUserRepository.findByUsername(authentication.getName());
+	 *		return new MyAuthentication(authentication, user);
+	 *	});
+	 * 
+ * + * This method takes precedence over {@link #setAuthoritiesExtractor(Converter)} and + * {@link #setAuthoritiesMapper(GrantedAuthoritiesMapper)}. + * + * @param responseAuthenticationConverter the {@link Converter} to use + * @since 5.4 + */ + public void setResponseAuthenticationConverter( + Converter responseAuthenticationConverter) { + Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null"); + this.responseAuthenticationConverter = responseAuthenticationConverter; + } + /** * Sets the {@link Converter} used for extracting assertion attributes that * can be mapped to authorities. * @param authoritiesExtractor the {@code Converter} used for mapping the * assertion attributes to authorities + * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead */ public void setAuthoritiesExtractor(Converter> authoritiesExtractor) { Assert.notNull(authoritiesExtractor, "authoritiesExtractor cannot be null"); @@ -271,6 +303,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * to a new set of authorities which will be associated to the {@link Saml2Authentication}. * Note: This implementation is only retrieving * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities + * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead */ public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { notNull(authoritiesMapper, "authoritiesMapper cannot be null"); @@ -286,6 +319,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.responseTimeValidationSkew = responseTimeValidationSkew; } + /** + * Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication} + * token into a {@link Saml2Authentication} + * + * @return the default response authentication converter strategy + * @since 5.4 + */ + public static Converter + createDefaultResponseAuthenticationConverter() { + return responseToken -> { + Saml2AuthenticationToken token = responseToken.token; + Response response = responseToken.response; + Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); + String username = assertion.getSubject().getNameID().getValue(); + Map> attributes = getAssertionAttributes(assertion); + return new Saml2Authentication( + new DefaultSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(), + Collections.singleton(new SimpleGrantedAuthority("ROLE_USER"))); + }; + } + /** * @param authentication the authentication request object, must be of type * {@link Saml2AuthenticationToken} @@ -300,7 +354,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi String serializedResponse = token.getSaml2Response(); Response response = parse(serializedResponse); process(token, response); - return this.authenticationConverter.apply(token).convert(response); + return this.responseAuthenticationConverter.convert(new ResponseToken(response, token)); } catch (Saml2AuthenticationException e) { throw e; } catch (Exception e) { @@ -496,7 +550,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - private Map> getAssertionAttributes(Assertion assertion) { + private static Map> getAssertionAttributes(Assertion assertion) { Map> attributeMap = new LinkedHashMap<>(); for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { for (Attribute attribute : attributeStatement.getAttributes()) { @@ -515,7 +569,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return attributeMap; } - private Object getXmlObjectValue(XMLObject xmlObject) { + private static Object getXmlObjectValue(XMLObject xmlObject) { if (xmlObject instanceof XSAny) { return ((XSAny) xmlObject).getTextContent(); } @@ -706,6 +760,29 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return new Saml2AuthenticationException(validationError(code, description), cause); } + /** + * A tuple containing an OpenSAML {@link Response} and its associated authentication token. + * + * @since 5.4 + */ + public static class ResponseToken { + private final Saml2AuthenticationToken token; + private final Response response; + + ResponseToken(Response response, Saml2AuthenticationToken token) { + this.token = token; + this.response = response; + } + + public Response getResponse() { + return this.response; + } + + public Saml2AuthenticationToken getToken() { + return this.token; + } + } + /** * A tuple containing an OpenSAML {@link Assertion} and its associated authentication token. * diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 54033535e1..ab60156562 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -77,17 +77,20 @@ import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParamete import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; +import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator; +import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultResponseAuthenticationConverter; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signedResponseWithOneAssertion; import static org.springframework.util.StringUtils.hasText; /** @@ -103,6 +106,10 @@ public class OpenSamlAuthenticationProviderTests { private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal + ("name", Collections.emptyMap()); + private Saml2Authentication authentication = new Saml2Authentication + (this.principal, "response", Collections.emptyList()); @Rule public ExpectedException exception = ExpectedException.none(); @@ -380,7 +387,7 @@ public class OpenSamlAuthenticationProviderTests { signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class))) - .thenReturn(Saml2ResponseValidatorResult.success()); + .thenReturn(success()); provider.authenticate(token); verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)); } @@ -388,7 +395,7 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success()); + provider.setAssertionValidator(assertionToken -> success()); Response response = response(); Assertion assertion = assertion(); signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature @@ -424,6 +431,35 @@ public class OpenSamlAuthenticationProviderTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() { + Response response = signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + OpenSamlAuthenticationProvider.ResponseToken responseToken = + new OpenSamlAuthenticationProvider.ResponseToken(response, token); + Saml2Authentication authentication = createDefaultResponseAuthenticationConverter() + .convert(responseToken); + assertThat(authentication.getName()).isEqualTo("test@saml.user"); + } + + @Test + public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() { + Converter authenticationConverter = + mock(Converter.class); + OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); + provider.setResponseAuthenticationConverter(authenticationConverter); + Response response = signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + provider.authenticate(token); + verify(authenticationConverter).convert(any()); + } + + @Test + public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() { + assertThatCode(() -> this.provider.setResponseAuthenticationConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + private T build(QName qName) { return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 80d649bd3a..1df38bd978 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -79,6 +79,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; +import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; final class TestOpenSamlObjects { static { @@ -107,6 +108,12 @@ final class TestOpenSamlObjects { return response; } + static Response signedResponseWithOneAssertion() { + Response response = response(); + response.getAssertions().add(assertion()); + return signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + } + static Assertion assertion() { return assertion(USERNAME, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, DESTINATION); } @@ -135,7 +142,6 @@ final class TestOpenSamlObjects { return assertion; } - static Issuer issuer(String entityId) { Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME); issuer.setValue(entityId);