Polish SAML Attribute Support

Issue gh-8661
This commit is contained in:
Josh Cummings 2020-06-16 14:55:53 -06:00
parent eed33228f4
commit 360db53dd2
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 65 additions and 20 deletions

View File

@ -531,9 +531,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
private Object getXmlObjectValue(XMLObject xmlObject) { private Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject == null) {
return null;
}
if (xmlObject instanceof XSAny) { if (xmlObject instanceof XSAny) {
return getXSAnyObjectValue((XSAny) xmlObject); return getXSAnyObjectValue((XSAny) xmlObject);
} }

View File

@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.hamcrest.BaseMatcher; import org.hamcrest.BaseMatcher;
import org.hamcrest.Description; import org.hamcrest.Description;
@ -33,27 +37,40 @@ import org.joda.time.Duration;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.Response;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.xml.sax.InputSource;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; import static org.junit.Assert.assertEquals;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; import static org.junit.Assert.assertTrue;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; import static org.mockito.ArgumentMatchers.any;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; import static org.mockito.Mockito.atLeastOnce;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.test.util.AssertionErrors.assertEquals; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.test.util.AssertionErrors.assertTrue; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
import static org.springframework.util.StringUtils.hasText; import static org.springframework.util.StringUtils.hasText;
/** /**
@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
Response response = response(); Response response = response();
Assertion assertion = assertion(); Assertion assertion = assertion();
attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as)); List<AttributeStatement> attributes = attributeStatements();
assertion.getAttributeStatements().addAll(attributes);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Authentication authentication = this.provider.authenticate(token); Authentication authentication = this.provider.authenticate(token);
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
Map<String, Object> attributes = new LinkedHashMap<>(); Map<String, Object> expected = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
attributes.put("name", Collections.singletonList("John Doe")); expected.put("name", Collections.singletonList("John Doe"));
attributes.put("age", Collections.singletonList(21)); expected.put("age", Collections.singletonList(21));
attributes.put("website", Collections.singletonList("https://johndoe.com/")); expected.put("website", Collections.singletonList("https://johndoe.com/"));
attributes.put("registered", Collections.singletonList(true)); expected.put("registered", Collections.singletonList(true));
Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
attributes.put("registeredDate", Collections.singletonList(registeredDate)); expected.put("registeredDate", Collections.singletonList(registeredDate));
assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name")); assertEquals("John Doe", principal.getFirstAttribute("name"));
assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes())); assertEquals(expected, principal.getAttributes());
}
@Test
public void authenticateWhenAttributeValueMarshallerConfiguredThenUses() throws Exception {
Response response = response();
Assertion assertion = assertion();
List<AttributeStatement> attributes = attributeStatements();
assertion.getAttributeStatements().addAll(attributes);
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Element attributeElement = element("<element>value</element>");
Marshaller marshaller = mock(Marshaller.class);
when(marshaller.marshall(any(XMLObject.class))).thenReturn(attributeElement);
try {
XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME, marshaller);
this.provider.authenticate(token);
verify(marshaller, atLeastOnce()).marshall(any(XMLObject.class));
} finally {
XMLObjectProviderRegistrySupport.getMarshallerFactory().deregisterMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME);
}
} }
@Test @Test
@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests {
return new Saml2AuthenticationToken(payload, return new Saml2AuthenticationToken(payload,
DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials)); DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials));
} }
private static Element element(String xml) throws Exception {
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
DocumentBuilder builder = factory.newDocumentBuilder();
Document doc = builder.parse(new InputSource(new StringReader(xml)));
return doc.getDocumentElement();
}
} }