Simplify OpenSamlImplementation

- Removed reflection usage
- Simplified method signatures

Issue gh-7711
Fixes gh-8147
This commit is contained in:
Josh Cummings 2020-03-18 16:58:26 -06:00
parent 088ea07f07
commit 7f2f210eb8
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 94 additions and 101 deletions

View File

@ -16,22 +16,22 @@
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
import org.springframework.util.Assert;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import org.joda.time.DateTime;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyList;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
@ -46,7 +46,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@Override @Override
@Deprecated @Deprecated
public String createAuthenticationRequest(Saml2AuthenticationRequest request) { public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
return createAuthenticationRequest(request, request.getCredentials()); AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
request.getDestination(), request.getAssertionConsumerServiceUrl());
return this.saml.serialize(authnRequest, request.getCredentials());
} }
/** /**
@ -54,11 +56,11 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/ */
@Override @Override
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
List<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ? AuthnRequest authnRequest = createAuthnRequest(context);
context.getRelyingPartyRegistration().getSigningCredentials() : String xml = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ?
emptyList(); this.saml.serialize(authnRequest, context.getRelyingPartyRegistration().getSigningCredentials()) :
this.saml.serialize(authnRequest);
String xml = createAuthenticationRequest(context, signingCredentials);
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
.samlRequest(samlEncode(xml.getBytes(UTF_8))) .samlRequest(samlEncode(xml.getBytes(UTF_8)))
.build(); .build();
@ -69,7 +71,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
*/ */
@Override @Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) { public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
String xml = createAuthenticationRequest(context, emptyList()); AuthnRequest authnRequest = createAuthnRequest(context);
String xml = this.saml.serialize(authnRequest);
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
String deflatedAndEncoded = samlEncode(samlDeflate(xml)); String deflatedAndEncoded = samlEncode(samlDeflate(xml));
result.samlRequest(deflatedAndEncoded) result.samlRequest(deflatedAndEncoded)
@ -91,27 +94,24 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
return result.build(); return result.build();
} }
private String createAuthenticationRequest(Saml2AuthenticationRequestContext request, List<Saml2X509Credential> credentials) { private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
return createAuthenticationRequest(Saml2AuthenticationRequest.withAuthenticationRequestContext(request).build(), credentials); return createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl());
} }
private String createAuthenticationRequest(Saml2AuthenticationRequest context, List<Saml2X509Credential> credentials) { private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
AuthnRequest auth = this.saml.buildSAMLObject(AuthnRequest.class); AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
auth.setID("ARQ" + UUID.randomUUID().toString().substring(1)); auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
auth.setIssueInstant(new DateTime(this.clock.millis())); auth.setIssueInstant(new DateTime(this.clock.millis()));
auth.setForceAuthn(Boolean.FALSE); auth.setForceAuthn(Boolean.FALSE);
auth.setIsPassive(Boolean.FALSE); auth.setIsPassive(Boolean.FALSE);
auth.setProtocolBinding(protocolBinding); auth.setProtocolBinding(protocolBinding);
Issuer issuer = this.saml.buildSAMLObject(Issuer.class); Issuer iss = this.saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(context.getIssuer()); iss.setValue(issuer);
auth.setIssuer(issuer); auth.setIssuer(iss);
auth.setDestination(context.getDestination()); auth.setDestination(destination);
auth.setAssertionConsumerServiceURL(context.getAssertionConsumerServiceUrl()); auth.setAssertionConsumerServiceURL(assertionConsumerServiceUrl);
return this.saml.toXml( return auth;
auth,
credentials,
context.getIssuer()
);
} }
/** /**

View File

@ -16,6 +16,16 @@
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayInputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.XMLConstants;
import javax.xml.namespace.QName;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException; import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.xml.BasicParserPool; import net.shibboleth.utilities.java.support.xml.BasicParserPool;
import net.shibboleth.utilities.java.support.xml.SerializeSupport; import net.shibboleth.utilities.java.support.xml.SerializeSupport;
@ -24,13 +34,14 @@ import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.config.InitializationException; import org.opensaml.core.config.InitializationException;
import org.opensaml.core.config.InitializationService; import org.opensaml.core.config.InitializationService;
import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.MarshallerFactory; import org.opensaml.core.xml.io.MarshallerFactory;
import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.io.UnmarshallerFactory; import org.opensaml.core.xml.io.UnmarshallerFactory;
import org.opensaml.core.xml.io.UnmarshallingException; import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
import org.opensaml.security.SecurityException; import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential; import org.opensaml.security.credential.BasicCredential;
@ -47,28 +58,17 @@ import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyR
import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureException; import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureSupport; import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2Utils;
import org.springframework.util.Assert;
import org.springframework.web.util.UriUtils;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import javax.xml.XMLConstants; import org.springframework.security.saml2.Saml2Exception;
import javax.xml.namespace.QName; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import java.io.ByteArrayInputStream; import org.springframework.util.Assert;
import java.nio.charset.Charset; import org.springframework.web.util.UriUtils;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static java.lang.Boolean.FALSE; import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE; import static java.lang.Boolean.TRUE;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.springframework.util.StringUtils.hasText; import static org.springframework.util.StringUtils.hasText;
/** /**
@ -76,6 +76,8 @@ import static org.springframework.util.StringUtils.hasText;
*/ */
final class OpenSamlImplementation { final class OpenSamlImplementation {
private static OpenSamlImplementation instance = new OpenSamlImplementation(); private static OpenSamlImplementation instance = new OpenSamlImplementation();
private static XMLObjectBuilderFactory xmlObjectBuilderFactory =
XMLObjectProviderRegistrySupport.getBuilderFactory();
private final BasicParserPool parserPool = new BasicParserPool(); private final BasicParserPool parserPool = new BasicParserPool();
private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
@ -167,37 +169,31 @@ final class OpenSamlImplementation {
return this.encryptedKeyResolver; return this.encryptedKeyResolver;
} }
<T> T buildSAMLObject(final Class<T> clazz) { <T> T buildSamlObject(QName qName) {
try { return (T) xmlObjectBuilderFactory.getBuilder(qName).buildObject(qName);
QName defaultElementName = (QName) clazz.getDeclaredField("DEFAULT_ELEMENT_NAME").get(null);
return (T) getBuilderFactory().getBuilder(defaultElementName).buildObject(defaultElementName);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new Saml2Exception("Could not create SAML object", e);
}
} }
XMLObject resolve(String xml) { XMLObject resolve(String xml) {
return resolve(xml.getBytes(StandardCharsets.UTF_8)); return resolve(xml.getBytes(StandardCharsets.UTF_8));
} }
String toXml(XMLObject object, List<Saml2X509Credential> signingCredentials, String localSpEntityId) { String serialize(XMLObject xmlObject) {
if (object instanceof SignableSAMLObject && null != hasSigningCredential(signingCredentials)) {
signXmlObject(
(SignableSAMLObject) object,
signingCredentials,
localSpEntityId
);
}
final MarshallerFactory marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory(); final MarshallerFactory marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory();
try { try {
Element element = marshallerFactory.getMarshaller(object).marshall(object); Element element = marshallerFactory.getMarshaller(xmlObject).marshall(xmlObject);
return SerializeSupport.nodeToString(element); return SerializeSupport.nodeToString(element);
} catch (MarshallingException e) { } catch (MarshallingException e) {
throw new Saml2Exception(e); throw new Saml2Exception(e);
} }
} }
String serialize(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
if (hasSigningCredential(signingCredentials) != null) {
signAuthnRequest(authnRequest, signingCredentials);
}
return serialize(authnRequest);
}
/** /**
* Returns query parameter after creating a Query String signature * Returns query parameter after creating a Query String signature
* All return values are unencoded and will need to be encoded prior to sending * All return values are unencoded and will need to be encoded prior to sending
@ -306,15 +302,15 @@ final class OpenSamlImplementation {
return cred; return cred;
} }
private void signXmlObject(SignableSAMLObject object, List<Saml2X509Credential> signingCredentials, String entityId) { private void signAuthnRequest(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
SignatureSigningParameters parameters = new SignatureSigningParameters(); SignatureSigningParameters parameters = new SignatureSigningParameters();
Credential credential = getSigningCredential(signingCredentials, entityId); Credential credential = getSigningCredential(signingCredentials, authnRequest.getIssuer().getValue());
parameters.setSigningCredential(credential); parameters.setSigningCredential(credential);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256); parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
try { try {
SignatureSupport.signObject(object, parameters); SignatureSupport.signObject(authnRequest, parameters);
} catch (MarshallingException | SignatureException | SecurityException e) { } catch (MarshallingException | SignatureException | SecurityException e) {
throw new Saml2Exception(e); throw new Saml2Exception(e);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,8 +20,6 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import org.springframework.security.core.Authentication;
import org.hamcrest.BaseMatcher; import org.hamcrest.BaseMatcher;
import org.hamcrest.Description; import org.hamcrest.Description;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -37,12 +35,14 @@ import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.Response;
import org.springframework.security.core.Authentication;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptAssertion; import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptAssertion;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptNameId; import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.encryptNameId;
import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.signXmlObject; import static org.springframework.security.saml2.provider.service.authentication.Saml2CryptoTestSupport.signXmlObject;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.assertingPartyCredentials; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.assertingPartyCredentials;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.relyingPartyCredentials; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.relyingPartyCredentials;
import static org.springframework.test.util.AssertionErrors.assertTrue; import static org.springframework.test.util.AssertionErrors.assertTrue;
@ -95,7 +95,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
Assertion assertion = defaultAssertion(); Assertion assertion = defaultAssertion();
token = responseXml(assertion, idpEntityId); token = responseXml(assertion);
exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS)); exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
provider.authenticate(token); provider.authenticate(token);
} }
@ -116,7 +116,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() {
Response response = response(recipientUri + "invalid", idpEntityId); Response response = response(recipientUri + "invalid", idpEntityId);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION)); exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION));
provider.authenticate(token); provider.authenticate(token);
} }
@ -124,7 +124,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test @Test
public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() { public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() {
Response response = response(recipientUri, idpEntityId); Response response = response(recipientUri, idpEntityId);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
@ -139,7 +139,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response(recipientUri, idpEntityId); Response response = response(recipientUri, idpEntityId);
Assertion assertion = defaultAssertion(); Assertion assertion = defaultAssertion();
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
Saml2ErrorCodes.INVALID_SIGNATURE Saml2ErrorCodes.INVALID_SIGNATURE
@ -164,7 +164,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId recipientEntityId
); );
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
@ -185,7 +185,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId recipientEntityId
); );
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
@ -209,7 +209,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId recipientEntityId
); );
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
@ -232,7 +232,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId recipientEntityId
); );
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
provider.authenticate(token); provider.authenticate(token);
} }
@ -242,7 +242,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion(); Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials()); EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
exception.expect( exception.expect(
authenticationMatcher( authenticationMatcher(
Saml2ErrorCodes.INVALID_SIGNATURE Saml2ErrorCodes.INVALID_SIGNATURE
@ -262,7 +262,7 @@ public class OpenSamlAuthenticationProviderTests {
); );
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials()); EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
provider.authenticate(token); provider.authenticate(token);
} }
@ -277,7 +277,7 @@ public class OpenSamlAuthenticationProviderTests {
assertingPartyCredentials(), assertingPartyCredentials(),
recipientEntityId recipientEntityId
); );
token = responseXml(response, idpEntityId); token = responseXml(response);
provider.authenticate(token); provider.authenticate(token);
} }
@ -295,7 +295,7 @@ public class OpenSamlAuthenticationProviderTests {
recipientEntityId recipientEntityId
); );
response.getAssertions().add(assertion); response.getAssertions().add(assertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
provider.authenticate(token); provider.authenticate(token);
} }
@ -306,7 +306,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion(); Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials()); EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
token = new Saml2AuthenticationToken( token = new Saml2AuthenticationToken(
token.getSaml2Response(), token.getSaml2Response(),
@ -331,7 +331,7 @@ public class OpenSamlAuthenticationProviderTests {
Assertion assertion = defaultAssertion(); Assertion assertion = defaultAssertion();
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials()); EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
token = new Saml2AuthenticationToken( token = new Saml2AuthenticationToken(
token.getSaml2Response(), token.getSaml2Response(),
@ -361,7 +361,7 @@ public class OpenSamlAuthenticationProviderTests {
); );
EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials()); EncryptedAssertion encryptedAssertion = encryptAssertion(assertion, assertingPartyCredentials());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
token = responseXml(response, idpEntityId); token = responseXml(response);
Saml2Authentication authentication = (Saml2Authentication) provider.authenticate(token); Saml2Authentication authentication = (Saml2Authentication) provider.authenticate(token);
@ -381,11 +381,8 @@ public class OpenSamlAuthenticationProviderTests {
); );
} }
private Saml2AuthenticationToken responseXml( private Saml2AuthenticationToken responseXml(XMLObject assertion) {
XMLObject object, String xml = saml.serialize(assertion);
String issuerEntityId
) {
String xml = saml.toXml(object, emptyList(), issuerEntityId);
return new Saml2AuthenticationToken( return new Saml2AuthenticationToken(
xml, xml,
recipientUri, recipientUri,

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import java.util.UUID;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.Duration; import org.joda.time.Duration;
import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.SAMLVersion;
@ -28,13 +30,11 @@ import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation; import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData; import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import java.util.UUID;
final class TestSaml2AuthenticationObjects { final class TestSaml2AuthenticationObjects {
private static OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); private static OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
static Response response(String destination, String issuerEntityId) { static Response response(String destination, String issuerEntityId) {
Response response = saml.buildSAMLObject(Response.class); Response response = saml.buildSamlObject(Response.DEFAULT_ELEMENT_NAME);
response.setID("R"+UUID.randomUUID().toString()); response.setID("R"+UUID.randomUUID().toString());
response.setIssueInstant(DateTime.now()); response.setIssueInstant(DateTime.now());
response.setVersion(SAMLVersion.VERSION_20); response.setVersion(SAMLVersion.VERSION_20);
@ -49,7 +49,7 @@ final class TestSaml2AuthenticationObjects {
String recipientEntityId, String recipientEntityId,
String recipientUri String recipientUri
) { ) {
Assertion assertion = saml.buildSAMLObject(Assertion.class); Assertion assertion = saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
assertion.setID("A"+ UUID.randomUUID().toString()); assertion.setID("A"+ UUID.randomUUID().toString());
assertion.setIssueInstant(DateTime.now()); assertion.setIssueInstant(DateTime.now());
assertion.setVersion(SAMLVersion.VERSION_20); assertion.setVersion(SAMLVersion.VERSION_20);
@ -69,13 +69,13 @@ final class TestSaml2AuthenticationObjects {
static Issuer issuer(String entityId) { static Issuer issuer(String entityId) {
Issuer issuer = saml.buildSAMLObject(Issuer.class); Issuer issuer = saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(entityId); issuer.setValue(entityId);
return issuer; return issuer;
} }
static Subject subject(String principalName) { static Subject subject(String principalName) {
Subject subject = saml.buildSAMLObject(Subject.class); Subject subject = saml.buildSamlObject(Subject.DEFAULT_ELEMENT_NAME);
if (principalName != null) { if (principalName != null) {
subject.setNameID(nameId(principalName)); subject.setNameID(nameId(principalName));
@ -85,17 +85,17 @@ final class TestSaml2AuthenticationObjects {
} }
static NameID nameId(String principalName) { static NameID nameId(String principalName) {
NameID nameId = saml.buildSAMLObject(NameID.class); NameID nameId = saml.buildSamlObject(NameID.DEFAULT_ELEMENT_NAME);
nameId.setValue(principalName); nameId.setValue(principalName);
return nameId; return nameId;
} }
static SubjectConfirmation subjectConfirmation() { static SubjectConfirmation subjectConfirmation() {
return saml.buildSAMLObject(SubjectConfirmation.class); return saml.buildSamlObject(SubjectConfirmation.DEFAULT_ELEMENT_NAME);
} }
static SubjectConfirmationData subjectConfirmationData(String recipient) { static SubjectConfirmationData subjectConfirmationData(String recipient) {
SubjectConfirmationData subject = saml.buildSAMLObject(SubjectConfirmationData.class); SubjectConfirmationData subject = saml.buildSamlObject(SubjectConfirmationData.DEFAULT_ELEMENT_NAME);
subject.setRecipient(recipient); subject.setRecipient(recipient);
subject.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000))); subject.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000)));
subject.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000))); subject.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000)));
@ -103,7 +103,7 @@ final class TestSaml2AuthenticationObjects {
} }
static Conditions conditions() { static Conditions conditions() {
Conditions conditions = saml.buildSAMLObject(Conditions.class); Conditions conditions = saml.buildSamlObject(Conditions.DEFAULT_ELEMENT_NAME);
conditions.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000))); conditions.setNotBefore(DateTime.now().minus(Duration.millis(5 * 60 * 1000)));
conditions.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000))); conditions.setNotOnOrAfter(DateTime.now().plus(Duration.millis(5 * 60 * 1000)));
return conditions; return conditions;