diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 9e50a2e152..2b90bb16b4 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -71,6 +71,7 @@ import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken; +import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StringUtils; @@ -252,28 +253,22 @@ public class OpenSaml4AuthenticationProviderTests { @Test public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() { - XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller( - TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME, - new TestCustomOpenSamlObject.CustomSamlObjectMarshaller()); - XMLObjectProviderRegistrySupport.getUnmarshallerFactory().registerUnmarshaller( - TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME, - new TestCustomOpenSamlObject.CustomSamlObjectUnmarshaller()); Response response = response(); Assertion assertion = assertion(); - List attributes = TestOpenSamlObjects.customAttributeStatements(); - assertion.getAttributeStatements().addAll(attributes); + AttributeStatement attribute = TestOpenSamlObjects.customAttributeStatement("Address", + TestCustomOpenSamlObjects.instance()); + assertion.getAttributeStatements().add(attribute); TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, verifying(registration())); Authentication authentication = this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); - TestCustomOpenSamlObject.CustomSamlObject customSamlObject; - customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) principal.getAttribute("Address").get(0); - assertThat(customSamlObject.getStreet()).isEqualTo("Test Street"); - assertThat(customSamlObject.getStreetNumber()).isEqualTo("1"); - assertThat(customSamlObject.getZIP()).isEqualTo("11111"); - assertThat(customSamlObject.getCity()).isEqualTo("Test City"); + CustomOpenSamlObject address = (CustomOpenSamlObject) principal.getAttribute("Address").get(0); + assertThat(address.getStreet()).isEqualTo("Test Street"); + assertThat(address.getStreetNumber()).isEqualTo("1"); + assertThat(address.getZIP()).isEqualTo("11111"); + assertThat(address.getCity()).isEqualTo("Test City"); } @Test diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java similarity index 62% rename from saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java rename to saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java index 72f18d3ce7..deb642aeeb 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java @@ -29,18 +29,54 @@ import org.opensaml.core.xml.AbstractXMLObjectBuilder; import org.opensaml.core.xml.ElementExtensibleXMLObject; import org.opensaml.core.xml.Namespace; import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.AbstractXMLObjectMarshaller; import org.opensaml.core.xml.io.AbstractXMLObjectUnmarshaller; import org.opensaml.core.xml.io.UnmarshallingException; import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.schema.impl.XSAnyBuilder; import org.opensaml.core.xml.util.IndexedXMLObjectChildrenList; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AttributeValue; import org.w3c.dom.Element; -public class TestCustomOpenSamlObject { +import org.springframework.security.saml2.core.OpenSamlInitializationService; - public interface CustomSamlObject extends ElementExtensibleXMLObject { +public class TestCustomOpenSamlObjects { + + static { + OpenSamlInitializationService.initialize(); + XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller( + CustomOpenSamlObject.TYPE_NAME, + new TestCustomOpenSamlObjects.CustomSamlObjectMarshaller()); + XMLObjectProviderRegistrySupport.getUnmarshallerFactory().registerUnmarshaller( + CustomOpenSamlObject.TYPE_NAME, + new TestCustomOpenSamlObjects.CustomSamlObjectUnmarshaller()); + } + + public static CustomOpenSamlObject instance() { + CustomOpenSamlObject samlObject = new TestCustomOpenSamlObjects.CustomSamlObjectBuilder() + .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, CustomOpenSamlObject.TYPE_NAME); + XSAny street = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "Street", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + street.setTextContent("Test Street"); + samlObject.getUnknownXMLObjects().add(street); + XSAny streetNumber = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, + "Number", CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + streetNumber.setTextContent("1"); + samlObject.getUnknownXMLObjects().add(streetNumber); + XSAny zip = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "ZIP", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + zip.setTextContent("11111"); + samlObject.getUnknownXMLObjects().add(zip); + XSAny city = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "City", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + city.setTextContent("Test City"); + samlObject.getUnknownXMLObjects().add(city); + return samlObject; + } + + public interface CustomOpenSamlObject extends ElementExtensibleXMLObject { String TYPE_LOCAL_NAME = "CustomType"; @@ -61,8 +97,8 @@ public class TestCustomOpenSamlObject { } - public static class CustomSamlObjectImpl extends AbstractXMLObject - implements TestCustomOpenSamlObject.CustomSamlObject { + public static class CustomOpenSamlObjectImpl extends AbstractXMLObject + implements CustomOpenSamlObject { @Nonnull private IndexedXMLObjectChildrenList unknownXMLObjects; @@ -74,7 +110,7 @@ public class TestCustomOpenSamlObject { * represents * @param namespacePrefix the prefix for the given namespace */ - protected CustomSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName, + protected CustomOpenSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName, @Nullable String namespacePrefix) { super(namespaceURI, elementLocalName, namespacePrefix); super.getNamespaceManager().registerNamespaceDeclaration(new Namespace(CUSTOM_NS, TYPE_CUSTOM_PREFIX)); @@ -122,13 +158,13 @@ public class TestCustomOpenSamlObject { } public static class CustomSamlObjectBuilder - extends AbstractXMLObjectBuilder { + extends AbstractXMLObjectBuilder { @Nonnull @Override - public TestCustomOpenSamlObject.CustomSamlObject buildObject(@Nullable String namespaceURI, - @Nonnull String localName, @Nullable String namespacePrefix) { - return new TestCustomOpenSamlObject.CustomSamlObjectImpl(namespaceURI, localName, namespacePrefix); + public CustomOpenSamlObject buildObject(@Nullable String namespaceURI, + @Nonnull String localName, @Nullable String namespacePrefix) { + return new CustomOpenSamlObjectImpl(namespaceURI, localName, namespacePrefix); } } @@ -141,7 +177,7 @@ public class TestCustomOpenSamlObject { @Override protected void marshallElementContent(@Nonnull XMLObject xmlObject, @Nonnull Element domElement) { - final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) xmlObject; + final CustomOpenSamlObject customSamlObject = (CustomOpenSamlObject) xmlObject; for (XMLObject object : customSamlObject.getOrderedChildren()) { ElementSupport.appendChildElement(domElement, object.getDOM()); @@ -159,7 +195,7 @@ public class TestCustomOpenSamlObject { @Override protected void processChildElement(@Nonnull XMLObject parentXMLObject, @Nonnull XMLObject childXMLObject) throws UnmarshallingException { - final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) parentXMLObject; + final CustomOpenSamlObject customSamlObject = (CustomOpenSamlObject) parentXMLObject; super.processChildElement(customSamlObject, childXMLObject); customSamlObject.getUnknownXMLObjects().add(childXMLObject); } @@ -167,9 +203,9 @@ public class TestCustomOpenSamlObject { @Nonnull @Override protected XMLObject buildXMLObject(@Nonnull Element domElement) { - return new TestCustomOpenSamlObject.CustomSamlObjectImpl(SAMLConstants.SAML20_NS, + return new CustomOpenSamlObjectImpl(SAMLConstants.SAML20_NS, AttributeValue.DEFAULT_ELEMENT_LOCAL_NAME, - TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index c9bb8cd696..e0edad7136 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -296,35 +296,15 @@ public final class TestOpenSamlObjects { return attribute; } - static List customAttributeStatements() { - List attributeStatements = new ArrayList<>(); + static AttributeStatement customAttributeStatement(String attributeName, XMLObject customAttributeValue) { AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder(); AttributeBuilder attributeBuilder = new AttributeBuilder(); Attribute attribute = attributeBuilder.buildObject(); - attribute.setName("Address"); - TestCustomOpenSamlObject.CustomSamlObject samlObject = new TestCustomOpenSamlObject.CustomSamlObjectBuilder() - .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME); - XSAny street = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "Street", - TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); - street.setTextContent("Test Street"); - samlObject.getUnknownXMLObjects().add(street); - XSAny streetNumber = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, - "Number", TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); - streetNumber.setTextContent("1"); - samlObject.getUnknownXMLObjects().add(streetNumber); - XSAny zip = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "ZIP", - TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); - zip.setTextContent("11111"); - samlObject.getUnknownXMLObjects().add(zip); - XSAny city = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "City", - TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX); - city.setTextContent("Test City"); - samlObject.getUnknownXMLObjects().add(city); - attribute.getAttributeValues().add(samlObject); + attribute.setName(attributeName); + attribute.getAttributeValues().add(customAttributeValue); AttributeStatement attributeStatement = attributeStatementBuilder.buildObject(); attributeStatement.getAttributes().add(attribute); - attributeStatements.add(attributeStatement); - return attributeStatements; + return attributeStatement; } static List attributeStatements() {