Polish OpenSamlAuthenticationRequestFactory

- Refactored to use SAMLMetadataSignatureSigningParametersResolver

Issue gh-7758
This commit is contained in:
Josh Cummings 2020-09-25 16:27:01 -06:00
parent 2ee455b7bf
commit a36baffb3a
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 117 additions and 51 deletions

View File

@ -21,11 +21,14 @@ import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.time.Clock;
import java.time.Instant;
import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.joda.time.DateTime;
import org.opensaml.core.config.ConfigurationService;
@ -37,15 +40,18 @@ import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType;
import org.opensaml.xmlsec.SignatureSigningParameters;
import org.opensaml.xmlsec.SignatureSigningParametersResolver;
import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
import org.opensaml.xmlsec.crypto.XMLSigningUtil;
import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.w3c.dom.Element;
@ -58,6 +64,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
/**
@ -105,9 +112,17 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
if (credential.isSigningCredential()) {
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
request.getIssuer());
return serialize(sign(authnRequest, cred));
X509Certificate certificate = credential.getCertificate();
PrivateKey privateKey = credential.getPrivateKey();
BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
cred.setEntityId(request.getIssuer());
cred.setUsageType(UsageType.SIGNING);
SignatureSigningParameters parameters = new SignatureSigningParameters();
parameters.setSigningCredential(cred);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
return serialize(sign(authnRequest, parameters));
}
}
throw new IllegalArgumentException("No signing credential provided");
@ -132,16 +147,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration()
.getSigningX509Credentials();
for (Saml2X509Credential credential : signingCredentials) {
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), "");
Map<String, String> signedParams = signQueryParameters(cred, deflatedAndEncoded,
context.getRelayState());
return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState"))
.sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build();
Map<String, String> parameters = new LinkedHashMap<>();
parameters.put("SAMLRequest", deflatedAndEncoded);
if (StringUtils.hasText(context.getRelayState())) {
parameters.put("RelayState", context.getRelayState());
}
throw new Saml2Exception("No signing credential provided");
sign(parameters, context.getRelyingPartyRegistration());
return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build();
}
return result.build();
}
@ -211,59 +223,39 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) {
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
relyingPartyRegistration.getEntityId());
return sign(authnRequest, cred);
}
throw new IllegalArgumentException("No signing credential provided");
SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
return sign(authnRequest, parameters);
}
private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) {
SignatureSigningParameters parameters = new SignatureSigningParameters();
parameters.setSigningCredential(credential);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) {
try {
SignatureSupport.signObject(authnRequest, parameters);
return authnRequest;
}
catch (MarshallingException | SignatureException | SecurityException ex) {
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private Credential getSigningCredential(X509Certificate certificate, PrivateKey privateKey, String entityId) {
BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
cred.setEntityId(entityId);
cred.setUsageType(UsageType.SIGNING);
return cred;
private void sign(Map<String, String> components, RelyingPartyRegistration relyingPartyRegistration) {
SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
sign(components, parameters);
}
private Map<String, String> signQueryParameters(Credential credential, String samlRequest, String relayState) {
Assert.notNull(samlRequest, "samlRequest cannot be null");
String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
StringBuilder queryString = new StringBuilder();
queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
.append("&");
if (StringUtils.hasText(relayState)) {
queryString.append("RelayState").append("=")
.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
private void sign(Map<String, String> components, SignatureSigningParameters parameters) {
Credential credential = parameters.getSigningCredential();
String algorithmUri = parameters.getSignatureAlgorithm();
components.put("SigAlg", algorithmUri);
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
for (Map.Entry<String, String> component : components.entrySet()) {
builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
}
queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
String queryString = builder.build(true).toString().substring(1);
try {
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
queryString.toString().getBytes(StandardCharsets.UTF_8));
queryString.getBytes(StandardCharsets.UTF_8));
String b64Signature = Saml2Utils.samlEncode(rawSignature);
Map<String, String> result = new LinkedHashMap<>();
result.put("SAMLRequest", samlRequest);
if (StringUtils.hasText(relayState)) {
result.put("RelayState", relayState);
}
result.put("SigAlg", algorithmUri);
result.put("Signature", b64Signature);
return result;
components.put("Signature", b64Signature);
}
catch (SecurityException ex) {
throw new Saml2Exception(ex);
@ -280,4 +272,40 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
}
private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
List<String> algorithms = Collections.singletonList(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
CriteriaSet criteria = new CriteriaSet();
BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
signingConfiguration.setSigningCredentials(credentials);
signingConfiguration.setSignatureAlgorithms(algorithms);
signingConfiguration.setSignatureReferenceDigestMethods(digests);
signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
try {
SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
Assert.notNull(parameters, "Failed to resolve any signing credential");
return parameters;
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private List<Credential> resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
X509Certificate certificate = x509Credential.getCertificate();
PrivateKey privateKey = x509Credential.getPrivateKey();
BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
credential.setEntityId(relyingPartyRegistration.getEntityId());
credential.setUsageType(UsageType.SIGNING);
credentials.add(credential);
}
return credentials;
}
}

View File

@ -26,16 +26,20 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -110,6 +114,28 @@ public class OpenSamlAuthenticationRequestFactoryTests {
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
}
@Test
public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
.relyingPartyRegistration(this.relyingPartyRegistration).build();
Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context);
assertThat(request.getRelayState()).isEqualTo("Relay State Value");
assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
assertThat(request.getSignature()).isNotNull();
}
@Test
public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyVerifyingCredential();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
}
@Test
public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
@ -139,6 +165,18 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.contains("ds:Signature");
}
@Test
public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyVerifyingCredential();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
}
@Test
public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);