diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 13671526bc..ffac5f5bf1 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -648,7 +648,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv if (xmlObject instanceof XSDateTime) { return ((XSDateTime) xmlObject).getValue(); } - return null; + return xmlObject; } private static Saml2AuthenticationException createAuthenticationException(String code, String message, diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index f05c3b0523..9e50a2e152 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -250,6 +250,32 @@ public class OpenSaml4AuthenticationProviderTests { assertThat(principal.getSessionIndexes()).contains("session-index"); } + @Test + public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() { + XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller( + TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME, + new TestCustomOpenSamlObject.CustomSamlObjectMarshaller()); + XMLObjectProviderRegistrySupport.getUnmarshallerFactory().registerUnmarshaller( + TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME, + new TestCustomOpenSamlObject.CustomSamlObjectUnmarshaller()); + Response response = response(); + Assertion assertion = assertion(); + List attributes = TestOpenSamlObjects.customAttributeStatements(); + assertion.getAttributeStatements().addAll(attributes); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, verifying(registration())); + Authentication authentication = this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + TestCustomOpenSamlObject.CustomSamlObject customSamlObject; + customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) principal.getAttribute("Address").get(0); + assertThat(customSamlObject.getStreet()).isEqualTo("Test Street"); + assertThat(customSamlObject.getStreetNumber()).isEqualTo("1"); + assertThat(customSamlObject.getZIP()).isEqualTo("11111"); + assertThat(customSamlObject.getCity()).isEqualTo("Test City"); + } + @Test public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() { Response response = response(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java new file mode 100644 index 0000000000..72f18d3ce7 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java @@ -0,0 +1,177 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.xml.namespace.QName; + +import net.shibboleth.utilities.java.support.xml.ElementSupport; +import org.opensaml.core.xml.AbstractXMLObject; +import org.opensaml.core.xml.AbstractXMLObjectBuilder; +import org.opensaml.core.xml.ElementExtensibleXMLObject; +import org.opensaml.core.xml.Namespace; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.io.AbstractXMLObjectMarshaller; +import org.opensaml.core.xml.io.AbstractXMLObjectUnmarshaller; +import org.opensaml.core.xml.io.UnmarshallingException; +import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.util.IndexedXMLObjectChildrenList; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.AttributeValue; +import org.w3c.dom.Element; + +public class TestCustomOpenSamlObject { + + public interface CustomSamlObject extends ElementExtensibleXMLObject { + + String TYPE_LOCAL_NAME = "CustomType"; + + String TYPE_CUSTOM_PREFIX = "custom"; + + String CUSTOM_NS = "https://custom.com/schema/custom"; + + /** QName of the CustomType type. */ + QName TYPE_NAME = new QName(CUSTOM_NS, TYPE_LOCAL_NAME, TYPE_CUSTOM_PREFIX); + + String getStreet(); + + String getStreetNumber(); + + String getZIP(); + + String getCity(); + + } + + public static class CustomSamlObjectImpl extends AbstractXMLObject + implements TestCustomOpenSamlObject.CustomSamlObject { + + @Nonnull + private IndexedXMLObjectChildrenList unknownXMLObjects; + + /** + * Constructor. + * @param namespaceURI the namespace the element is in + * @param elementLocalName the local name of the XML element this Object + * represents + * @param namespacePrefix the prefix for the given namespace + */ + protected CustomSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName, + @Nullable String namespacePrefix) { + super(namespaceURI, elementLocalName, namespacePrefix); + super.getNamespaceManager().registerNamespaceDeclaration(new Namespace(CUSTOM_NS, TYPE_CUSTOM_PREFIX)); + this.unknownXMLObjects = new IndexedXMLObjectChildrenList<>(this); + } + + @Nonnull + @Override + public List getUnknownXMLObjects() { + return this.unknownXMLObjects; + } + + @Nonnull + @Override + public List getUnknownXMLObjects(@Nonnull QName typeOrName) { + return (List) this.unknownXMLObjects.subList(typeOrName); + } + + @Nullable + @Override + public List getOrderedChildren() { + return Collections.unmodifiableList(this.unknownXMLObjects); + } + + @Override + public String getStreet() { + return ((XSAny) getOrderedChildren().get(0)).getTextContent(); + } + + @Override + public String getStreetNumber() { + return ((XSAny) getOrderedChildren().get(1)).getTextContent(); + } + + @Override + public String getZIP() { + return ((XSAny) getOrderedChildren().get(2)).getTextContent(); + } + + @Override + public String getCity() { + return ((XSAny) getOrderedChildren().get(3)).getTextContent(); + } + + } + + public static class CustomSamlObjectBuilder + extends AbstractXMLObjectBuilder { + + @Nonnull + @Override + public TestCustomOpenSamlObject.CustomSamlObject buildObject(@Nullable String namespaceURI, + @Nonnull String localName, @Nullable String namespacePrefix) { + return new TestCustomOpenSamlObject.CustomSamlObjectImpl(namespaceURI, localName, namespacePrefix); + } + + } + + public static class CustomSamlObjectMarshaller extends AbstractXMLObjectMarshaller { + + public CustomSamlObjectMarshaller() { + super(); + } + + @Override + protected void marshallElementContent(@Nonnull XMLObject xmlObject, @Nonnull Element domElement) { + final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) xmlObject; + + for (XMLObject object : customSamlObject.getOrderedChildren()) { + ElementSupport.appendChildElement(domElement, object.getDOM()); + } + } + + } + + public static class CustomSamlObjectUnmarshaller extends AbstractXMLObjectUnmarshaller { + + public CustomSamlObjectUnmarshaller() { + super(); + } + + @Override + protected void processChildElement(@Nonnull XMLObject parentXMLObject, @Nonnull XMLObject childXMLObject) + throws UnmarshallingException { + final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) parentXMLObject; + super.processChildElement(customSamlObject, childXMLObject); + customSamlObject.getUnknownXMLObjects().add(childXMLObject); + } + + @Nonnull + @Override + protected XMLObject buildXMLObject(@Nonnull Element domElement) { + return new TestCustomOpenSamlObject.CustomSamlObjectImpl(SAMLConstants.SAML20_NS, + AttributeValue.DEFAULT_ELEMENT_LOCAL_NAME, + TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + } + + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 1c7b023314..c9bb8cd696 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -296,6 +296,37 @@ public final class TestOpenSamlObjects { return attribute; } + static List customAttributeStatements() { + List attributeStatements = new ArrayList<>(); + AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder(); + AttributeBuilder attributeBuilder = new AttributeBuilder(); + Attribute attribute = attributeBuilder.buildObject(); + attribute.setName("Address"); + TestCustomOpenSamlObject.CustomSamlObject samlObject = new TestCustomOpenSamlObject.CustomSamlObjectBuilder() + .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME); + XSAny street = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "Street", + TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + street.setTextContent("Test Street"); + samlObject.getUnknownXMLObjects().add(street); + XSAny streetNumber = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, + "Number", TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + streetNumber.setTextContent("1"); + samlObject.getUnknownXMLObjects().add(streetNumber); + XSAny zip = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "ZIP", + TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + zip.setTextContent("11111"); + samlObject.getUnknownXMLObjects().add(zip); + XSAny city = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "City", + TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + city.setTextContent("Test City"); + samlObject.getUnknownXMLObjects().add(city); + attribute.getAttributeValues().add(samlObject); + AttributeStatement attributeStatement = attributeStatementBuilder.buildObject(); + attributeStatement.getAttributes().add(attribute); + attributeStatements.add(attributeStatement); + return attributeStatements; + } + static List attributeStatements() { List attributeStatements = new ArrayList<>(); AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();