From 2e2da06bdbeaf99a45ac727f555b7d5c067df936 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 23 Jul 2020 16:08:16 -0600 Subject: [PATCH] OpenSamlAuthenticationProvider Uses OpenSAML Directly Closes gh-8773 --- .../OpenSamlAuthenticationProvider.java | 51 ++++++++++++++----- .../OpenSamlAuthenticationProviderTests.java | 30 ++++++++--- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index eb8b353cb5..e8ecc5d0d7 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -15,6 +15,8 @@ */ package org.springframework.security.saml2.provider.service.authentication; +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -32,13 +34,17 @@ import java.util.function.Function; import javax.annotation.Nonnull; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.ParserPool; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.joda.time.DateTime; +import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.criterion.EntityIdCriterion; import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.core.xml.schema.XSBoolean; import org.opensaml.core.xml.schema.XSBooleanValue; @@ -65,6 +71,7 @@ import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; import org.opensaml.saml.saml2.encryption.Decrypter; import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; @@ -88,6 +95,8 @@ import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver; import org.opensaml.xmlsec.signature.support.SignaturePrevalidator; import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; @@ -120,7 +129,6 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA; import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS; import static org.springframework.util.Assert.notNull; /** @@ -167,7 +175,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class); - private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); + private final XMLObjectProviderRegistry registry; + private final ResponseUnmarshaller responseUnmarshaller; + private final ParserPool parserPool; private Converter> authoritiesExtractor = (a -> singletonList(new SimpleGrantedAuthority("ROLE_USER"))); @@ -192,6 +202,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); }; + /** + * Creates an {@link OpenSamlAuthenticationProvider} + */ + public OpenSamlAuthenticationProvider() { + this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + this.responseUnmarshaller = (ResponseUnmarshaller) this.registry.getUnmarshallerFactory() + .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME); + this.parserPool = this.registry.getParserPool(); + } + /** * Sets the {@link Converter} used for extracting assertion attributes that * can be mapped to authorities. @@ -265,15 +285,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException { try { - Object result = this.saml.resolve(response); - if (result instanceof Response) { - return (Response) result; - } - else { - throw authException(UNKNOWN_RESPONSE_CLASS, "Invalid response class:" + result.getClass().getName()); - } - } catch (Saml2Exception x) { - throw authException(MALFORMED_RESPONSE_DATA, x.getMessage(), x); + Document document = this.parserPool.parse(new ByteArrayInputStream( + response.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + return (Response) this.responseUnmarshaller.unmarshall(element); + } + catch (Exception e) { + throw authException(MALFORMED_RESPONSE_DATA, e.getMessage(), e); } } @@ -427,9 +445,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } private Object getXSAnyObjectValue(XSAny xsAny) { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny); + Marshaller marshaller = this.registry.getMarshallerFactory().getMarshaller(xsAny); if (marshaller != null) { - return this.saml.serialize(xsAny); + try { + Element element = marshaller.marshall(xsAny); + return SerializeSupport.nodeToString(element); + } catch (MarshallingException e) { + throw new Saml2Exception(e); + } } return xsAny.getTextContent(); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 69983e956c..1f8c00414b 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -29,6 +29,7 @@ import java.util.Map; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.hamcrest.Matcher; @@ -40,6 +41,7 @@ import org.junit.rules.ExpectedException; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; @@ -52,6 +54,7 @@ import org.w3c.dom.Element; import org.xml.sax.InputSource; import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.credentials.Saml2X509Credential; import static org.assertj.core.api.Assertions.assertThat; @@ -60,6 +63,8 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; +import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; @@ -85,8 +90,6 @@ public class OpenSamlAuthenticationProviderTests { private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - private OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); - @Rule public ExpectedException exception = ExpectedException.none(); @@ -108,10 +111,11 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS)); + this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); - Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME); - this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential())); + Assertion assertion = (Assertion) getBuilderFactory().getBuilder(Assertion.DEFAULT_ELEMENT_NAME) + .buildObject(Assertion.DEFAULT_ELEMENT_NAME); + this.provider.authenticate(token(serialize(assertion), relyingPartyVerifyingCredential())); } @Test @@ -316,7 +320,7 @@ public class OpenSamlAuthenticationProviderTests { Response response = response(); EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(serialize(response), relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @@ -329,7 +333,7 @@ public class OpenSamlAuthenticationProviderTests { Response response = response(); EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(this.saml.serialize(response), assertingPartyPrivateCredential()); + Saml2AuthenticationToken token = token(serialize(response), assertingPartyPrivateCredential()); this.provider.authenticate(token); } @@ -349,6 +353,16 @@ public class OpenSamlAuthenticationProviderTests { objectOutputStream.flush(); } + private String serialize(XMLObject object) { + try { + Marshaller marshaller = getMarshallerFactory().getMarshaller(object); + Element element = marshaller.marshall(object); + return SerializeSupport.nodeToString(element); + } catch (MarshallingException e) { + throw new Saml2Exception(e); + } + } + private Matcher authenticationMatcher(String code) { return authenticationMatcher(code, null); } @@ -382,7 +396,7 @@ public class OpenSamlAuthenticationProviderTests { } private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) { - String payload = this.saml.serialize(response); + String payload = serialize(response); return token(payload, credentials); }