From a54e77a3c3d33cb63803628e113720379c981e51 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 17 Jul 2020 12:19:27 -0600 Subject: [PATCH] Saml2AuthenticationToken takes a RelyingPartyRegistration Closes gh-8845 --- .../OpenSamlAuthenticationProvider.java | 31 +++---- .../Saml2AuthenticationToken.java | 80 +++++++++++++++---- .../Saml2WebSsoAuthenticationFilter.java | 18 +++-- .../OpenSamlAuthenticationProviderTests.java | 8 +- 4 files changed, 93 insertions(+), 44 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 8ed7bf0c94..8dc7fb1385 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -480,13 +480,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) { Set credentials = new HashSet<>(); - for (Saml2X509Credential key : token.getX509Credentials()) { - if (!key.isSignatureVerficationCredential()) { - continue; - } + for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) { BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); cred.setUsageType(UsageType.SIGNING); - cred.setEntityId(token.getIdpEntityId()); + cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()); credentials.add(cred); } CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); @@ -506,13 +503,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi Map validationExceptions = new LinkedHashMap<>(); String destination = response.getDestination(); - if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) { + String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + if (StringUtils.hasText(destination) && !destination.equals(location)) { String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]"; validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message)); } String issuer = response.getIssuer().getValue(); - String assertingPartyEntityId = token.getIdpEntityId(); + String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId(); if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message)); @@ -538,11 +536,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return encrypted -> { Saml2AuthenticationException last = authException(DECRYPTION_ERROR, "No valid decryption credentials found."); - List decryptionCredentials = token.getX509Credentials(); + List decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials(); for (Saml2X509Credential key : decryptionCredentials) { - if (!key.isDecryptionCredential()) { - continue; - } Decrypter decrypter = getDecrypter(key); try { return decrypter.decrypt(encrypted); @@ -623,11 +618,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) { Set credentials = new HashSet<>(); - for (Saml2X509Credential key : token.getX509Credentials()) { - if (!key.isSignatureVerficationCredential()) continue; + for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) { BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); cred.setUsageType(UsageType.SIGNING); - cred.setEntityId(token.getIdpEntityId()); + cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()); credentials.add(cred); } CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); @@ -709,10 +703,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } }, token -> { + String audience = token.getRelyingPartyRegistration().getEntityId(); + String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); Map params = new HashMap<>(); params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis()); - params.put(COND_VALID_AUDIENCES, singleton(token.getIdpEntityId())); - params.put(SC_VALID_RECIPIENTS, singleton(token.getRecipientUri())); + params.put(COND_VALID_AUDIENCES, singleton(audience)); + params.put(SC_VALID_RECIPIENTS, singleton(recipient)); params.putAll(this.validationContextParameters); return new ValidationContext(params); }); @@ -734,9 +730,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi return encrypted -> { Saml2AuthenticationException last = authException(DECRYPTION_ERROR, "No valid decryption credentials found."); - List decryptionCredentials = token.getX509Credentials(); + List decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials(); for (Saml2X509Credential key : decryptionCredentials) { - if (!key.isDecryptionCredential()) continue; Decrypter decrypter = getDecrypter(key); try { return (NameID) decrypter.decrypt(encrypted); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java index a19a024eed..22146994f0 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,23 +16,51 @@ package org.springframework.security.saml2.provider.service.authentication; +import java.util.Collections; +import java.util.List; + import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.Assert; -import java.util.List; +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId; /** * Represents an incoming SAML 2.0 response containing an assertion that has not been validated. * {@link Saml2AuthenticationToken#isAuthenticated()} will always return false. + * * @since 5.2 + * @author Filip Hanik + * @author Josh Cummings */ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { + private final RelyingPartyRegistration relyingPartyRegistration; private final String saml2Response; - private final String recipientUri; - private String idpEntityId; - private String localSpEntityId; - private List credentials; + + /** + * Creates a {@link Saml2AuthenticationToken} with the provided parameters + * + * Note that the given {@link RelyingPartyRegistration} should have all its + * templates resolved at this point. See + * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter} + * for an example of performing that resolution. + * + * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to use + * @param saml2Response the SAML 2.0 response to authenticate + * + * @since 5.4 + */ + public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, + String saml2Response) { + + super(Collections.emptyList()); + Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); + Assert.notNull(saml2Response, "saml2Response cannot be null"); + this.relyingPartyRegistration = relyingPartyRegistration; + this.saml2Response = saml2Response; + } /** * Creates an authentication token from an incoming SAML 2 Response object @@ -41,18 +69,24 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { * @param idpEntityId the entity ID of the asserting entity * @param localSpEntityId the configured local SP, the relying party, entity ID * @param credentials the credentials configured for signature verification and decryption + * @deprecated Use {@link Saml2AuthenticationToken(RelyingPartyRegistration, String)} instead */ + @Deprecated public Saml2AuthenticationToken(String saml2Response, String recipientUri, String idpEntityId, String localSpEntityId, List credentials) { super(null); + this.relyingPartyRegistration = withRegistrationId(idpEntityId) + .entityId(localSpEntityId) + .assertionConsumerServiceLocation(recipientUri) + .credentials(c -> c.addAll(credentials)) + .assertingPartyDetails(assertingParty -> assertingParty + .entityId(idpEntityId) + .singleSignOnServiceLocation(idpEntityId)) + .build(); this.saml2Response = saml2Response; - this.recipientUri = recipientUri; - this.idpEntityId = idpEntityId; - this.localSpEntityId = localSpEntityId; - this.credentials = credentials; } /** @@ -73,6 +107,16 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { return null; } + /** + * Get the resolved {@link RelyingPartyRegistration} associated with the request + * + * @return the resolved {@link RelyingPartyRegistration} + * @since 5.4 + */ + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.relyingPartyRegistration; + } + /** * Returns inflated and decoded XML representation of the SAML 2 Response * @return inflated and decoded XML representation of the SAML 2 Response @@ -84,25 +128,31 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns the URI that the SAML 2 Response object came in on * @return URI as a string + * @deprecated Use {@link #getRelyingPartyRegistration().getAssertionConsumerServiceLocation()} instead */ + @Deprecated public String getRecipientUri() { - return this.recipientUri; + return this.relyingPartyRegistration.getAssertionConsumerServiceLocation(); } /** * Returns the configured entity ID of the receiving relying party, SP * @return an entityID for the configured local relying party + * @deprecated Use {@link #getRelyingPartyRegistration().getEntityId()} instead */ + @Deprecated public String getLocalSpEntityId() { - return this.localSpEntityId; + return this.relyingPartyRegistration.getEntityId(); } /** * Returns all the credentials associated with the relying party configuraiton * @return + * @deprecated Get the credentials through {@link #getRelyingPartyRegistration()} instead */ + @Deprecated public List getX509Credentials() { - return this.credentials; + return this.relyingPartyRegistration.getCredentials(); } /** @@ -126,8 +176,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns the configured IDP, asserting party, entity ID * @return a string representing the entity ID + * @deprecated Use {@link #getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()} instead */ + @Deprecated public String getIdpEntityId() { - return this.idpEntityId; + return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index e04456b004..ddcc854d0e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -35,6 +35,7 @@ import org.springframework.util.Assert; import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.util.StringUtils.hasText; /** @@ -98,14 +99,15 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce throw new Saml2AuthenticationException(saml2Error); } String applicationUri = Saml2ServletUtils.getApplicationUri(request); - String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp); - final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( - responseXml, - request.getRequestURL().toString(), - rp.getAssertingPartyDetails().getEntityId(), - localSpEntityId, - rp.getCredentials() - ); + String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp); + String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate( + rp.getAssertionConsumerServiceLocation(), applicationUri, rp); + RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp) + .entityId(relyingPartyEntityId) + .assertionConsumerServiceLocation(assertionConsumerServiceLocation) + .build(); + Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( + relyingPartyRegistration, responseXml); return getAuthenticationManager().authenticate(authentication); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index 64f5946dba..69983e956c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -111,14 +111,14 @@ public class OpenSamlAuthenticationProviderTests { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS)); Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME); - this.provider.authenticate(token(this.saml.serialize(assertion))); + this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential())); } @Test public void authenticateWhenXmlErrorThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); - Saml2AuthenticationToken token = token("invalid xml"); + Saml2AuthenticationToken token = token("invalid xml", relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @@ -149,7 +149,7 @@ public class OpenSamlAuthenticationProviderTests { Response response = response(); response.getAssertions().add(assertion()); - Saml2AuthenticationToken token = token(response); + Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @@ -316,7 +316,7 @@ public class OpenSamlAuthenticationProviderTests { Response response = response(); EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(this.saml.serialize(response)); + Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential()); this.provider.authenticate(token); }