Add SAML Attribute Support

Closes gh-8661
This commit is contained in:
Nikola Kostic 2020-06-09 13:32:37 +02:00 committed by Josh Cummings
parent efb6953017
commit eed33228f4
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
6 changed files with 287 additions and 9 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,11 +17,13 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@ -31,7 +33,19 @@ import javax.annotation.Nonnull;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.joda.time.DateTime;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.common.xml.SAMLConstants;
@ -45,6 +59,8 @@ import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
@ -205,8 +221,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
List<Assertion> validAssertions = validateResponse(token, response);
Assertion assertion = validAssertions.get(0);
String username = getUsername(token, assertion);
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(
new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(),
new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
} catch (Saml2AuthenticationException e) {
throw e;
@ -494,6 +511,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw last;
}
private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject);
if (attributeValue != null) {
attributeValues.add(attributeValue);
}
}
attributeMap.put(attribute.getName(), attributeValues);
}
}
return attributeMap;
}
private Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject == null) {
return null;
}
if (xmlObject instanceof XSAny) {
return getXSAnyObjectValue((XSAny) xmlObject);
}
if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
}
if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
}
if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getValue();
}
if (xmlObject instanceof XSBoolean) {
XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
return xsBooleanValue != null ? xsBooleanValue.getValue() : null;
}
if (xmlObject instanceof XSDateTime) {
DateTime dateTime = ((XSDateTime) xmlObject).getValue();
return dateTime != null ? Instant.ofEpochMilli(dateTime.getMillis()) : null;
}
return null;
}
private Object getXSAnyObjectValue(XSAny xsAny) {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny);
if (marshaller != null) {
return this.saml.serialize(xsAny);
}
return xsAny.getTextContent();
}
private Saml2Error validationError(String code, String description) {
return new Saml2Error(code, description);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,7 +16,13 @@
package org.springframework.security.saml2.provider.service.authentication;
import org.springframework.lang.Nullable;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* Saml2 representation of an {@link AuthenticatedPrincipal}.
@ -25,4 +31,40 @@ import org.springframework.security.core.AuthenticatedPrincipal;
* @since 5.2.2
*/
public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
/**
* Get the first value of Saml2 token attribute by name
*
* @param name the name of the attribute
* @param <A> the type of the attribute
* @return the first attribute value or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> A getFirstAttribute(String name) {
List<A> values = getAttribute(name);
return CollectionUtils.firstElement(values);
}
/**
* Get the Saml2 token attribute by name
*
* @param name the name of the attribute
* @param <A> the type of the attribute
* @return the attribute or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> List<A> getAttribute(String name) {
return (List<A>) getAttributes().get(name);
}
/**
* Get the Saml2 token attributes
*
* @return the Saml2 token attributes
* @since 5.4
*/
default Map<String, List<Object>> getAttributes() {
return Collections.emptyMap();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
* Default implementation of a {@link Saml2AuthenticatedPrincipal}.
@ -27,13 +29,20 @@ import java.io.Serializable;
class SimpleSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable {
private final String name;
private final Map<String, List<Object>> attributes;
SimpleSaml2AuthenticatedPrincipal(String name) {
SimpleSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes) {
this.name = name;
this.attributes = attributes;
}
@Override
public String getName() {
return this.name;
}
@Override
public Map<String, List<Object>> getAttributes() {
return this.attributes;
}
}

View File

@ -19,7 +19,11 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
@ -39,6 +43,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
@ -47,6 +52,7 @@ import static org.springframework.security.saml2.credentials.TestSaml2X509Creden
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.test.util.AssertionErrors.assertEquals;
import static org.springframework.test.util.AssertionErrors.assertTrue;
import static org.springframework.util.StringUtils.hasText;
@ -193,6 +199,30 @@ public class OpenSamlAuthenticationProviderTests {
this.provider.authenticate(token);
}
@Test
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
Response response = response();
Assertion assertion = assertion();
attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as));
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Authentication authentication = this.provider.authenticate(token);
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
Map<String, Object> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
attributes.put("name", Collections.singletonList("John Doe"));
attributes.put("age", Collections.singletonList(21));
attributes.put("website", Collections.singletonList("https://johndoe.com/"));
attributes.put("registered", Collections.singletonList(true));
Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
attributes.put("registeredDate", Collections.singletonList(registeredDate));
assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name"));
assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes()));
}
@Test
public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,15 +16,58 @@
package org.springframework.security.saml2.provider.service.authentication;
import org.junit.Assert;
import org.joda.time.DateTime;
import org.junit.Test;
import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
public class SimpleSaml2AuthenticatedPrincipalTests {
@Test
public void createSimpleSaml2AuthenticatedPrincipal() {
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user");
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.getName()).isEqualTo("user");
assertThat(principal.getAttributes()).isEqualTo(attributes);
}
Assert.assertEquals("user", principal.getName());
@Test
public void getFirstAttributeWhenStringValueThenReturnsValue() {
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.<String>getFirstAttribute("email")).isEqualTo(attributes.get("email").get(0));
}
@Test
public void getAttributeWhenStringValuesThenReturnsValues() {
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.<String>getAttribute("email")).isEqualTo(attributes.get("email"));
}
@Test
public void getAttributeWhenDistinctValuesThenReturnsValues() {
final Boolean registered = true;
final Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("registration", Arrays.asList(registered, registeredDate));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
List<Object> registrationInfo = principal.getAttribute("registration");
assertThat(registrationInfo).isNotNull();
assertThat((Boolean) registrationInfo.get(0)).isEqualTo(registered);
assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate);
}
}

View File

@ -19,6 +19,8 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.UUID;
import java.util.List;
import java.util.ArrayList;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
@ -26,9 +28,26 @@ import org.apache.xml.security.encryption.XMLCipherParameters;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.core.xml.schema.impl.XSAnyBuilder;
import org.opensaml.core.xml.schema.impl.XSBooleanBuilder;
import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder;
import org.opensaml.core.xml.schema.impl.XSIntegerBuilder;
import org.opensaml.core.xml.schema.impl.XSStringBuilder;
import org.opensaml.core.xml.schema.impl.XSURIBuilder;
import org.opensaml.saml.common.SAMLVersion;
import org.opensaml.saml.common.SignableSAMLObject;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID;
@ -38,6 +57,8 @@ import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import org.opensaml.saml.saml2.core.impl.AttributeBuilder;
import org.opensaml.saml.saml2.core.impl.AttributeStatementBuilder;
import org.opensaml.saml.saml2.encryption.Encrypter;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential;
@ -222,4 +243,66 @@ final class TestOpenSamlObjects {
return encrypter;
}
static List<AttributeStatement> attributeStatements() {
List<AttributeStatement> attributeStatements = new ArrayList<>();
AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();
AttributeBuilder attributeBuilder = new AttributeBuilder();
AttributeStatement attrStmt1 = attributeStatementBuilder.buildObject();
Attribute emailAttr = attributeBuilder.buildObject();
emailAttr.setName("email");
XSAny email1 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME);
email1.setTextContent("john.doe@example.com");
emailAttr.getAttributeValues().add(email1);
XSAny email2 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME);
email2.setTextContent("doe.john@example.com");
emailAttr.getAttributeValues().add(email2);
attrStmt1.getAttributes().add(emailAttr);
Attribute nameAttr = attributeBuilder.buildObject();
nameAttr.setName("name");
XSString name = new XSStringBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME);
name.setValue("John Doe");
nameAttr.getAttributeValues().add(name);
attrStmt1.getAttributes().add(nameAttr);
Attribute ageAttr = attributeBuilder.buildObject();
ageAttr.setName("age");
XSInteger age = new XSIntegerBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSInteger.TYPE_NAME);
age.setValue(21);
ageAttr.getAttributeValues().add(age);
attrStmt1.getAttributes().add(ageAttr);
attributeStatements.add(attrStmt1);
AttributeStatement attrStmt2 = attributeStatementBuilder.buildObject();
Attribute websiteAttr = attributeBuilder.buildObject();
websiteAttr.setName("website");
XSURI uri = new XSURIBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSURI.TYPE_NAME);
uri.setValue("https://johndoe.com/");
websiteAttr.getAttributeValues().add(uri);
attrStmt2.getAttributes().add(websiteAttr);
Attribute registeredAttr = attributeBuilder.buildObject();
registeredAttr.setName("registered");
XSBoolean registered = new XSBooleanBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSBoolean.TYPE_NAME);
registered.setValue(new XSBooleanValue(true, false));
registeredAttr.getAttributeValues().add(registered);
attrStmt2.getAttributes().add(registeredAttr);
Attribute registeredDateAttr = attributeBuilder.buildObject();
registeredDateAttr.setName("registeredDate");
XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSDateTime.TYPE_NAME);
registeredDate.setValue(DateTime.parse("1970-01-01T00:00:00Z"));
registeredDateAttr.getAttributeValues().add(registeredDate);
attrStmt2.getAttributes().add(registeredDateAttr);
attributeStatements.add(attrStmt2);
return attributeStatements;
}
}