diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 69d97e2240..f5f69e9a39 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -531,9 +531,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } private Object getXmlObjectValue(XMLObject xmlObject) { - if (xmlObject == null) { - return null; - } if (xmlObject instanceof XSAny) { return getXSAnyObjectValue((XSAny) xmlObject); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 7e91bb85fd..b3fb27f1da 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.io.StringReader; import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -33,27 +37,40 @@ import org.joda.time.Duration; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AttributeValue; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.Response; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.xml.sax.InputSource; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.credentials.Saml2X509Credential; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.test.util.AssertionErrors.assertEquals; -import static org.springframework.test.util.AssertionErrors.assertTrue; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; import static org.springframework.util.StringUtils.hasText; /** @@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests { public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { Response response = response(); Assertion assertion = assertion(); - attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as)); + List attributes = attributeStatements(); + assertion.getAttributeStatements().addAll(attributes); signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); Authentication authentication = this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); - Map attributes = new LinkedHashMap<>(); - attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); - attributes.put("name", Collections.singletonList("John Doe")); - attributes.put("age", Collections.singletonList(21)); - attributes.put("website", Collections.singletonList("https://johndoe.com/")); - attributes.put("registered", Collections.singletonList(true)); + Map expected = new LinkedHashMap<>(); + expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + expected.put("name", Collections.singletonList("John Doe")); + expected.put("age", Collections.singletonList(21)); + expected.put("website", Collections.singletonList("https://johndoe.com/")); + expected.put("registered", Collections.singletonList(true)); Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); - attributes.put("registeredDate", Collections.singletonList(registeredDate)); + expected.put("registeredDate", Collections.singletonList(registeredDate)); - assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name")); - assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes())); + assertEquals("John Doe", principal.getFirstAttribute("name")); + assertEquals(expected, principal.getAttributes()); + } + + @Test + public void authenticateWhenAttributeValueMarshallerConfiguredThenUses() throws Exception { + Response response = response(); + Assertion assertion = assertion(); + List attributes = attributeStatements(); + assertion.getAttributeStatements().addAll(attributes); + signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + + Element attributeElement = element("value"); + Marshaller marshaller = mock(Marshaller.class); + when(marshaller.marshall(any(XMLObject.class))).thenReturn(attributeElement); + + try { + XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME, marshaller); + this.provider.authenticate(token); + verify(marshaller, atLeastOnce()).marshall(any(XMLObject.class)); + } finally { + XMLObjectProviderRegistrySupport.getMarshallerFactory().deregisterMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME); + } } @Test @@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests { return new Saml2AuthenticationToken(payload, DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials)); } + + private static Element element(String xml) throws Exception { + DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); + DocumentBuilder builder = factory.newDocumentBuilder(); + Document doc = builder.parse(new InputSource(new StringReader(xml))); + return doc.getDocumentElement(); + } }