Simplify OpenSamlImplementation

- Removed reflection usage
- Simplified method signatures

Issue gh-7711
Fixes gh-8147
This commit is contained in:
Josh Cummings 2020-03-18 16:58:26 -06:00
parent 1bbbf3be3d
commit 15cc15cc3c
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 94 additions and 101 deletions

View File

@ -16,22 +16,22 @@
package org.springframework.security.saml2.provider.service.authentication;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
import org.springframework.util.Assert;
import java.time.Clock;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyList;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
@ -46,7 +46,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@Override
@Deprecated
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
return createAuthenticationRequest(request, request.getCredentials());
AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
request.getDestination(), request.getAssertionConsumerServiceUrl());
return this.saml.serialize(authnRequest, request.getCredentials());
}
/**
@ -54,11 +56,11 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/
@Override
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
List<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ?
context.getRelyingPartyRegistration().getSigningCredentials() :
emptyList();
AuthnRequest authnRequest = createAuthnRequest(context);
String xml = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ?
this.saml.serialize(authnRequest, context.getRelyingPartyRegistration().getSigningCredentials()) :
this.saml.serialize(authnRequest);
String xml = createAuthenticationRequest(context, signingCredentials);
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
.samlRequest(samlEncode(xml.getBytes(UTF_8)))
.build();
@ -69,7 +71,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/
@Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
String xml = createAuthenticationRequest(context, emptyList());
AuthnRequest authnRequest = createAuthnRequest(context);
String xml = this.saml.serialize(authnRequest);
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
result.samlRequest(deflatedAndEncoded)
@ -91,27 +94,24 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return result.build();
}
private String createAuthenticationRequest(Saml2AuthenticationRequestContext request, List<Saml2X509Credential> credentials) {
return createAuthenticationRequest(Saml2AuthenticationRequest.withAuthenticationRequestContext(request).build(), credentials);
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
return createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl());
}
private String createAuthenticationRequest(Saml2AuthenticationRequest context, List<Saml2X509Credential> credentials) {
AuthnRequest auth = this.saml.buildSAMLObject(AuthnRequest.class);
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
auth.setIssueInstant(new DateTime(this.clock.millis()));
auth.setForceAuthn(Boolean.FALSE);
auth.setIsPassive(Boolean.FALSE);
auth.setProtocolBinding(protocolBinding);
Issuer issuer = this.saml.buildSAMLObject(Issuer.class);
issuer.setValue(context.getIssuer());
auth.setIssuer(issuer);
auth.setDestination(context.getDestination());
auth.setAssertionConsumerServiceURL(context.getAssertionConsumerServiceUrl());
return this.saml.toXml(
auth,
credentials,
context.getIssuer()
);
Issuer iss = this.saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
iss.setValue(issuer);
auth.setIssuer(iss);
auth.setDestination(destination);
auth.setAssertionConsumerServiceURL(assertionConsumerServiceUrl);
return auth;
}
/**

View File

@ -16,6 +16,16 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.XMLConstants;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.xml.BasicParserPool;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
@ -24,13 +34,14 @@ import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.config.InitializationException;
import org.opensaml.core.config.InitializationService;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.MarshallerFactory;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.io.UnmarshallerFactory;
import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.saml.common.SignableSAMLObject;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential;
@ -47,28 +58,17 @@ import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyR
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2Utils;
import org.springframework.util.Assert;
import org.springframework.web.util.UriUtils;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import javax.xml.XMLConstants;
import javax.xml.namespace.QName;
import java.io.ByteArrayInputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.util.Assert;
import org.springframework.web.util.UriUtils;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static java.util.Arrays.asList;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.springframework.util.StringUtils.hasText;
/**
@ -76,6 +76,8 @@ import static org.springframework.util.StringUtils.hasText;
*/
final class OpenSamlImplementation {
private static OpenSamlImplementation instance = new OpenSamlImplementation();
private static XMLObjectBuilderFactory xmlObjectBuilderFactory =
XMLObjectProviderRegistrySupport.getBuilderFactory();
private final BasicParserPool parserPool = new BasicParserPool();
private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
@ -167,37 +169,31 @@ final class OpenSamlImplementation {
return this.encryptedKeyResolver;
}
<T> T buildSAMLObject(final Class<T> clazz) {
try {
QName defaultElementName = (QName) clazz.getDeclaredField("DEFAULT_ELEMENT_NAME").get(null);
return (T) getBuilderFactory().getBuilder(defaultElementName).buildObject(defaultElementName);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new Saml2Exception("Could not create SAML object", e);
}
<T> T buildSamlObject(QName qName) {
return (T) xmlObjectBuilderFactory.getBuilder(qName).buildObject(qName);
}
XMLObject resolve(String xml) {
return resolve(xml.getBytes(StandardCharsets.UTF_8));
}
String toXml(XMLObject object, List<Saml2X509Credential> signingCredentials, String localSpEntityId) {
if (object instanceof SignableSAMLObject && null != hasSigningCredential(signingCredentials)) {
signXmlObject(
(SignableSAMLObject) object,
signingCredentials,
localSpEntityId
);
}
String serialize(XMLObject xmlObject) {
final MarshallerFactory marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory();
try {
Element element = marshallerFactory.getMarshaller(object).marshall(object);
Element element = marshallerFactory.getMarshaller(xmlObject).marshall(xmlObject);
return SerializeSupport.nodeToString(element);
} catch (MarshallingException e) {
throw new Saml2Exception(e);
}
}
String serialize(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
if (hasSigningCredential(signingCredentials) != null) {
signAuthnRequest(authnRequest, signingCredentials);
}
return serialize(authnRequest);
}
/**
* Returns query parameter after creating a Query String signature
* All return values are unencoded and will need to be encoded prior to sending
@ -306,15 +302,15 @@ final class OpenSamlImplementation {
return cred;
}
private void signXmlObject(SignableSAMLObject object, List<Saml2X509Credential> signingCredentials, String entityId) {
private void signAuthnRequest(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
SignatureSigningParameters parameters = new SignatureSigningParameters();
Credential credential = getSigningCredential(signingCredentials, entityId);
Credential credential = getSigningCredential(signingCredentials, authnRequest.getIssuer().getValue());
parameters.setSigningCredential(credential);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
try {
SignatureSupport.signObject(object, parameters);
SignatureSupport.signObject(authnRequest, parameters);
} catch (MarshallingException | SignatureException | SecurityException e) {
throw new Saml2Exception(e);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,8 +20,6 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import org.springframework.security.core.Authentication;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.joda.time.DateTime;
@ -37,12 +35,14 @@ import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response;
import org.springframework.security.core.Authentication;
import static java.util.Collections.emptyList;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptAssertion;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptNameId;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.signXmlObject;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.assertingPartyCredentials;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.relyingPartyCredentials;
import static org.springframework.test.util.AssertionErrors.assertTrue;
@ -95,7 +95,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
Assertion assertion = defaultAssertion();
token = responseXml(assertion, idpEntityId);
token = responseXml(assertion);
exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
provider.authenticate(token);
}
@ -116,7 +116,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() {
Response response = response(recipientUri + "invalid", idpEntityId);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION));
provider.authenticate(token);
}
@ -124,7 +124,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() {
Response response = response(recipientUri, idpEntityId);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
@ -139,7 +139,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response(recipientUri, idpEntityId);
Assertion assertion = defaultAssertion();
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
Saml2ErrorCodes.INVALID_SIGNATURE
@ -164,7 +164,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId
);
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
@ -185,7 +185,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId
);
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
@ -209,7 +209,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId
);
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
@ -232,7 +232,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId
);
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
provider.authenticate(token);
}
@ -242,7 +242,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
exception.expect(
authenticationMatcher(
Saml2ErrorCodes.INVALID_SIGNATURE
@ -262,7 +262,7 @@ public class OpenSamlAuthenticationProviderTests {
);
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
provider.authenticate(token);
}
@ -277,7 +277,7 @@ public class OpenSamlAuthenticationProviderTests {
assertingPartyCredentials(),
recipientEntityId
);
token = responseXml(response, idpEntityId);
token = responseXml(response);
provider.authenticate(token);
}
@ -295,7 +295,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId
);
response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
provider.authenticate(token);
}
@ -306,7 +306,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
token = new Saml2AuthenticationToken(
token.getSaml2Response(),
@ -331,7 +331,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
token = new Saml2AuthenticationToken(
token.getSaml2Response(),
@ -361,7 +361,7 @@ public class OpenSamlAuthenticationProviderTests {
);
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId);
token = responseXml(response);
Saml2Authentication authentication = (Saml2Authentication) provider.authenticate(token);
@ -381,11 +381,8 @@ public class OpenSamlAuthenticationProviderTests {
);
}
private Saml2AuthenticationToken responseXml(
XMLObject object,
String issuerEntityId
) {
String xml = saml.toXml(object, emptyList(), issuerEntityId);
private Saml2AuthenticationToken responseXml(XMLObject assertion) {
String xml = saml.serialize(assertion);
return new Saml2AuthenticationToken(
xml,
recipientUri,

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.authentication;
import java.util.UUID;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.opensaml.saml.common.SAMLVersion;
@ -28,13 +30,11 @@ import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import java.util.UUID;
final class TestSaml2AuthenticationObjects {
private static OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
static Response response(String destination, String issuerEntityId) {
Response response = saml.buildSAMLObject(Response.class);
Response response = saml.buildSamlObject(Response.DEFAULT_ELEMENT_NAME);
response.setID("R"+UUID.randomUUID().toString());
response.setIssueInstant(DateTime.now());
response.setVersion(SAMLVersion.VERSION_20);
@ -49,7 +49,7 @@ final class TestSaml2AuthenticationObjects {
String recipientEntityId,
String recipientUri
) {
Assertion assertion = saml.buildSAMLObject(Assertion.class);
Assertion assertion = saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
assertion.setID("A"+ UUID.randomUUID().toString());
assertion.setIssueInstant(DateTime.now());
assertion.setVersion(SAMLVersion.VERSION_20);
@ -69,13 +69,13 @@ final class TestSaml2AuthenticationObjects {
static Issuer issuer(String entityId) {
Issuer issuer = saml.buildSAMLObject(Issuer.class);
Issuer issuer = saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(entityId);
return issuer;
}
static Subject subject(String principalName) {
Subject subject = saml.buildSAMLObject(Subject.class);
Subject subject = saml.buildSamlObject(Subject.DEFAULT_ELEMENT_NAME);
if (principalName != null) {
subject.setNameID(nameId(principalName));
@ -85,17 +85,17 @@ final class TestSaml2AuthenticationObjects {
}
static NameID nameId(String principalName) {
NameID nameId = saml.buildSAMLObject(NameID.class);
NameID nameId = saml.buildSamlObject(NameID.DEFAULT_ELEMENT_NAME);
nameId.setValue(principalName);
return nameId;
}
static SubjectConfirmation subjectConfirmation() {
return saml.buildSAMLObject(SubjectConfirmation.class);
return saml.buildSamlObject(SubjectConfirmation.DEFAULT_ELEMENT_NAME);
}
static SubjectConfirmationData subjectConfirmationData(String recipient) {
SubjectConfirmationData subject = saml.buildSAMLObject(SubjectConfirmationData.class);
SubjectConfirmationData subject = saml.buildSamlObject(SubjectConfirmationData.DEFAULT_ELEMENT_NAME);
subject.setRecipient(recipient);
subject.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000)));
subject.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000)));
@ -103,7 +103,7 @@ final class TestSaml2AuthenticationObjects {
}
static Conditions conditions() {
Conditions conditions = saml.buildSAMLObject(Conditions.class);
Conditions conditions = saml.buildSamlObject(Conditions.DEFAULT_ELEMENT_NAME);
conditions.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000)));
conditions.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000)));
return conditions;