OpenSamlAuthenticationProvider Uses OpenSAML Directly

Closes gh-8773
This commit is contained in:
Josh Cummings 2020-07-23 16:08:16 -06:00
parent 77128a94e2
commit 2e2da06bdb
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 59 additions and 22 deletions

View File

@ -15,6 +15,8 @@
*/
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@ -32,13 +34,17 @@ import java.util.function.Function;
import javax.annotation.Nonnull;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.joda.time.DateTime;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
@ -65,6 +71,7 @@ import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
import org.opensaml.saml.saml2.encryption.Decrypter;
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
@ -88,6 +95,8 @@ import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver;
import org.opensaml.xmlsec.signature.support.SignaturePrevalidator;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AbstractAuthenticationToken;
@ -120,7 +129,6 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS;
import static org.springframework.util.Assert.notNull;
/**
@ -167,7 +175,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class);
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
private final XMLObjectProviderRegistry registry;
private final ResponseUnmarshaller responseUnmarshaller;
private final ParserPool parserPool;
private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor =
(a -> singletonList(new SimpleGrantedAuthority("ROLE_USER")));
@ -192,6 +202,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
};
/**
* Creates an {@link OpenSamlAuthenticationProvider}
*/
public OpenSamlAuthenticationProvider() {
this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
this.responseUnmarshaller = (ResponseUnmarshaller) this.registry.getUnmarshallerFactory()
.getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
this.parserPool = this.registry.getParserPool();
}
/**
* Sets the {@link Converter} used for extracting assertion attributes that
* can be mapped to authorities.
@ -265,15 +285,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
try {
Object result = this.saml.resolve(response);
if (result instanceof Response) {
return (Response) result;
}
else {
throw authException(UNKNOWN_RESPONSE_CLASS, "Invalid response class:" + result.getClass().getName());
}
} catch (Saml2Exception x) {
throw authException(MALFORMED_RESPONSE_DATA, x.getMessage(), x);
Document document = this.parserPool.parse(new ByteArrayInputStream(
response.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (Response) this.responseUnmarshaller.unmarshall(element);
}
catch (Exception e) {
throw authException(MALFORMED_RESPONSE_DATA, e.getMessage(), e);
}
}
@ -427,9 +445,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
private Object getXSAnyObjectValue(XSAny xsAny) {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny);
Marshaller marshaller = this.registry.getMarshallerFactory().getMarshaller(xsAny);
if (marshaller != null) {
return this.saml.serialize(xsAny);
try {
Element element = marshaller.marshall(xsAny);
return SerializeSupport.nodeToString(element);
} catch (MarshallingException e) {
throw new Saml2Exception(e);
}
}
return xsAny.getTextContent();
}

View File

@ -29,6 +29,7 @@ import java.util.Map;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
@ -40,6 +41,7 @@ import org.junit.rules.ExpectedException;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
@ -52,6 +54,7 @@ import org.w3c.dom.Element;
import org.xml.sax.InputSource;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import static org.assertj.core.api.Assertions.assertThat;
@ -60,6 +63,8 @@ import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
@ -85,8 +90,6 @@ public class OpenSamlAuthenticationProviderTests {
private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
private OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
@Rule
public ExpectedException exception = ExpectedException.none();
@ -108,10 +111,11 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential()));
Assertion assertion = (Assertion) getBuilderFactory().getBuilder(Assertion.DEFAULT_ELEMENT_NAME)
.buildObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider.authenticate(token(serialize(assertion), relyingPartyVerifyingCredential()));
}
@Test
@ -316,7 +320,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential());
Saml2AuthenticationToken token = token(serialize(response), relyingPartyVerifyingCredential());
this.provider.authenticate(token);
}
@ -329,7 +333,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(this.saml.serialize(response), assertingPartyPrivateCredential());
Saml2AuthenticationToken token = token(serialize(response), assertingPartyPrivateCredential());
this.provider.authenticate(token);
}
@ -349,6 +353,16 @@ public class OpenSamlAuthenticationProviderTests {
objectOutputStream.flush();
}
private String serialize(XMLObject object) {
try {
Marshaller marshaller = getMarshallerFactory().getMarshaller(object);
Element element = marshaller.marshall(object);
return SerializeSupport.nodeToString(element);
} catch (MarshallingException e) {
throw new Saml2Exception(e);
}
}
private Matcher<Saml2AuthenticationException> authenticationMatcher(String code) {
return authenticationMatcher(code, null);
}
@ -382,7 +396,7 @@ public class OpenSamlAuthenticationProviderTests {
}
private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) {
String payload = this.saml.serialize(response);
String payload = serialize(response);
return token(payload, credentials);
}