diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 7324d25e86..69d97e2240 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,11 +17,13 @@ package org.springframework.security.saml2.provider.service.authentication; import java.security.cert.X509Certificate; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -31,7 +33,19 @@ import javax.annotation.Nonnull; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.joda.time.DateTime; import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; + +import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.schema.XSBoolean; +import org.opensaml.core.xml.schema.XSBooleanValue; +import org.opensaml.core.xml.schema.XSDateTime; +import org.opensaml.core.xml.schema.XSInteger; +import org.opensaml.core.xml.schema.XSString; +import org.opensaml.core.xml.schema.XSURI; import org.opensaml.saml.common.assertion.ValidationContext; import org.opensaml.saml.common.assertion.ValidationResult; import org.opensaml.saml.common.xml.SAMLConstants; @@ -45,6 +59,8 @@ import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator; import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator; import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator; import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; @@ -205,8 +221,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi List validAssertions = validateResponse(token, response); Assertion assertion = validAssertions.get(0); String username = getUsername(token, assertion); + Map> attributes = getAssertionAttributes(assertion); return new Saml2Authentication( - new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(), + new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); } catch (Saml2AuthenticationException e) { throw e; @@ -494,6 +511,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi throw last; } + private Map> getAssertionAttributes(Assertion assertion) { + Map> attributeMap = new LinkedHashMap<>(); + for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { + for (Attribute attribute : attributeStatement.getAttributes()) { + + List attributeValues = new ArrayList<>(); + for (XMLObject xmlObject : attribute.getAttributeValues()) { + Object attributeValue = getXmlObjectValue(xmlObject); + if (attributeValue != null) { + attributeValues.add(attributeValue); + } + } + attributeMap.put(attribute.getName(), attributeValues); + + } + } + return attributeMap; + } + + private Object getXmlObjectValue(XMLObject xmlObject) { + if (xmlObject == null) { + return null; + } + if (xmlObject instanceof XSAny) { + return getXSAnyObjectValue((XSAny) xmlObject); + } + if (xmlObject instanceof XSString) { + return ((XSString) xmlObject).getValue(); + } + if (xmlObject instanceof XSInteger) { + return ((XSInteger) xmlObject).getValue(); + } + if (xmlObject instanceof XSURI) { + return ((XSURI) xmlObject).getValue(); + } + if (xmlObject instanceof XSBoolean) { + XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue(); + return xsBooleanValue != null ? xsBooleanValue.getValue() : null; + } + if (xmlObject instanceof XSDateTime) { + DateTime dateTime = ((XSDateTime) xmlObject).getValue(); + return dateTime != null ? Instant.ofEpochMilli(dateTime.getMillis()) : null; + } + return null; + } + + private Object getXSAnyObjectValue(XSAny xsAny) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny); + if (marshaller != null) { + return this.saml.serialize(xsAny); + } + return xsAny.getTextContent(); + } + private Saml2Error validationError(String code, String description) { return new Saml2Error(code, description); } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java index 97bc90d65f..54cb297ffb 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,13 @@ package org.springframework.security.saml2.provider.service.authentication; +import org.springframework.lang.Nullable; import org.springframework.security.core.AuthenticatedPrincipal; +import org.springframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.List; +import java.util.Map; /** * Saml2 representation of an {@link AuthenticatedPrincipal}. @@ -25,4 +31,40 @@ import org.springframework.security.core.AuthenticatedPrincipal; * @since 5.2.2 */ public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal { + /** + * Get the first value of Saml2 token attribute by name + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the first attribute value or {@code null} otherwise + * @since 5.4 + */ + @Nullable + default A getFirstAttribute(String name) { + List values = getAttribute(name); + return CollectionUtils.firstElement(values); + } + + /** + * Get the Saml2 token attribute by name + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the attribute or {@code null} otherwise + * @since 5.4 + */ + @Nullable + default List getAttribute(String name) { + return (List) getAttributes().get(name); + } + + /** + * Get the Saml2 token attributes + * + * @return the Saml2 token attributes + * @since 5.4 + */ + default Map> getAttributes() { + return Collections.emptyMap(); + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipal.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipal.java index 3eb752c46a..d926d9c5bc 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipal.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipal.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.Serializable; +import java.util.List; +import java.util.Map; /** * Default implementation of a {@link Saml2AuthenticatedPrincipal}. @@ -27,13 +29,20 @@ import java.io.Serializable; class SimpleSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable { private final String name; + private final Map> attributes; - SimpleSaml2AuthenticatedPrincipal(String name) { + SimpleSaml2AuthenticatedPrincipal(String name, Map> attributes) { this.name = name; + this.attributes = attributes; } @Override public String getName() { return this.name; } + + @Override + public Map> getAttributes() { + return this.attributes; + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index c3e3c8317a..7e91bb85fd 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -19,7 +19,11 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.time.Instant; import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -39,6 +43,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.saml2.credentials.Saml2X509Credential; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; @@ -47,6 +52,7 @@ import static org.springframework.security.saml2.credentials.TestSaml2X509Creden import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; +import static org.springframework.test.util.AssertionErrors.assertEquals; import static org.springframework.test.util.AssertionErrors.assertTrue; import static org.springframework.util.StringUtils.hasText; @@ -193,6 +199,30 @@ public class OpenSamlAuthenticationProviderTests { this.provider.authenticate(token); } + @Test + public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as)); + signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Authentication authentication = this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + + Map attributes = new LinkedHashMap<>(); + attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + attributes.put("name", Collections.singletonList("John Doe")); + attributes.put("age", Collections.singletonList(21)); + attributes.put("website", Collections.singletonList("https://johndoe.com/")); + attributes.put("registered", Collections.singletonList(true)); + Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); + attributes.put("registeredDate", Collections.singletonList(registeredDate)); + + assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name")); + assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes())); + } + @Test public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE)); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipalTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipalTests.java index 5948ab7ca9..bd937e78f4 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipalTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipalTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,58 @@ package org.springframework.security.saml2.provider.service.authentication; -import org.junit.Assert; +import org.joda.time.DateTime; import org.junit.Test; +import java.time.Instant; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + public class SimpleSaml2AuthenticatedPrincipalTests { @Test public void createSimpleSaml2AuthenticatedPrincipal() { - SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user"); + Map> attributes = new LinkedHashMap<>(); + attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes); + assertThat(principal.getName()).isEqualTo("user"); + assertThat(principal.getAttributes()).isEqualTo(attributes); + } - Assert.assertEquals("user", principal.getName()); + @Test + public void getFirstAttributeWhenStringValueThenReturnsValue() { + Map> attributes = new LinkedHashMap<>(); + attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes); + assertThat(principal.getFirstAttribute("email")).isEqualTo(attributes.get("email").get(0)); + } + + @Test + public void getAttributeWhenStringValuesThenReturnsValues() { + Map> attributes = new LinkedHashMap<>(); + attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes); + assertThat(principal.getAttribute("email")).isEqualTo(attributes.get("email")); + } + + @Test + public void getAttributeWhenDistinctValuesThenReturnsValues() { + final Boolean registered = true; + final Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); + + Map> attributes = new LinkedHashMap<>(); + attributes.put("registration", Arrays.asList(registered, registeredDate)); + + SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes); + + List registrationInfo = principal.getAttribute("registration"); + + assertThat(registrationInfo).isNotNull(); + assertThat((Boolean) registrationInfo.get(0)).isEqualTo(registered); + assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 27d0d73129..79b8823141 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -19,6 +19,8 @@ package org.springframework.security.saml2.provider.service.authentication; import java.security.cert.X509Certificate; import java.util.Base64; import java.util.UUID; +import java.util.List; +import java.util.ArrayList; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; @@ -26,9 +28,26 @@ import org.apache.xml.security.encryption.XMLCipherParameters; import org.joda.time.DateTime; import org.joda.time.Duration; import org.opensaml.core.xml.io.MarshallingException; + +import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.schema.XSBoolean; +import org.opensaml.core.xml.schema.XSBooleanValue; +import org.opensaml.core.xml.schema.XSDateTime; +import org.opensaml.core.xml.schema.XSInteger; +import org.opensaml.core.xml.schema.XSString; +import org.opensaml.core.xml.schema.XSURI; +import org.opensaml.core.xml.schema.impl.XSAnyBuilder; +import org.opensaml.core.xml.schema.impl.XSBooleanBuilder; +import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder; +import org.opensaml.core.xml.schema.impl.XSIntegerBuilder; +import org.opensaml.core.xml.schema.impl.XSStringBuilder; +import org.opensaml.core.xml.schema.impl.XSURIBuilder; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AttributeValue; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; @@ -38,6 +57,8 @@ import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.Subject; import org.opensaml.saml.saml2.core.SubjectConfirmation; import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.core.impl.AttributeBuilder; +import org.opensaml.saml.saml2.core.impl.AttributeStatementBuilder; import org.opensaml.saml.saml2.encryption.Encrypter; import org.opensaml.security.SecurityException; import org.opensaml.security.credential.BasicCredential; @@ -222,4 +243,66 @@ final class TestOpenSamlObjects { return encrypter; } + + static List attributeStatements() { + List attributeStatements = new ArrayList<>(); + + AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder(); + AttributeBuilder attributeBuilder = new AttributeBuilder(); + + AttributeStatement attrStmt1 = attributeStatementBuilder.buildObject(); + + Attribute emailAttr = attributeBuilder.buildObject(); + emailAttr.setName("email"); + XSAny email1 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME); + email1.setTextContent("john.doe@example.com"); + emailAttr.getAttributeValues().add(email1); + XSAny email2 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME); + email2.setTextContent("doe.john@example.com"); + emailAttr.getAttributeValues().add(email2); + attrStmt1.getAttributes().add(emailAttr); + + Attribute nameAttr = attributeBuilder.buildObject(); + nameAttr.setName("name"); + XSString name = new XSStringBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME); + name.setValue("John Doe"); + nameAttr.getAttributeValues().add(name); + attrStmt1.getAttributes().add(nameAttr); + + Attribute ageAttr = attributeBuilder.buildObject(); + ageAttr.setName("age"); + XSInteger age = new XSIntegerBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSInteger.TYPE_NAME); + age.setValue(21); + ageAttr.getAttributeValues().add(age); + attrStmt1.getAttributes().add(ageAttr); + + attributeStatements.add(attrStmt1); + + AttributeStatement attrStmt2 = attributeStatementBuilder.buildObject(); + + Attribute websiteAttr = attributeBuilder.buildObject(); + websiteAttr.setName("website"); + XSURI uri = new XSURIBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSURI.TYPE_NAME); + uri.setValue("https://johndoe.com/"); + websiteAttr.getAttributeValues().add(uri); + attrStmt2.getAttributes().add(websiteAttr); + + Attribute registeredAttr = attributeBuilder.buildObject(); + registeredAttr.setName("registered"); + XSBoolean registered = new XSBooleanBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSBoolean.TYPE_NAME); + registered.setValue(new XSBooleanValue(true, false)); + registeredAttr.getAttributeValues().add(registered); + attrStmt2.getAttributes().add(registeredAttr); + + Attribute registeredDateAttr = attributeBuilder.buildObject(); + registeredDateAttr.setName("registeredDate"); + XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSDateTime.TYPE_NAME); + registeredDate.setValue(DateTime.parse("1970-01-01T00:00:00Z")); + registeredDateAttr.getAttributeValues().add(registeredDate); + attrStmt2.getAttributes().add(registeredDateAttr); + + attributeStatements.add(attrStmt2); + + return attributeStatements; + } }