Saml2AuthenticationToken takes a RelyingPartyRegistration

Closes gh-8845
This commit is contained in:
Josh Cummings 2020-07-17 12:19:27 -06:00
parent 44ec061f05
commit a54e77a3c3
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 93 additions and 44 deletions

View File

@ -480,13 +480,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) { private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
Set<Credential> credentials = new HashSet<>(); Set<Credential> credentials = new HashSet<>();
for (Saml2X509Credential key : token.getX509Credentials()) { for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) {
if (!key.isSignatureVerficationCredential()) {
continue;
}
BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING); cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(token.getIdpEntityId()); cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId());
credentials.add(cred); credentials.add(cred);
} }
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
@ -506,13 +503,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>(); Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
String destination = response.getDestination(); String destination = response.getDestination();
if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) { String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
if (StringUtils.hasText(destination) && !destination.equals(location)) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]"; String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message)); validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
} }
String issuer = response.getIssuer().getValue(); String issuer = response.getIssuer().getValue();
String assertingPartyEntityId = token.getIdpEntityId(); String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message)); validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
@ -538,11 +536,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return encrypted -> { return encrypted -> {
Saml2AuthenticationException last = Saml2AuthenticationException last =
authException(DECRYPTION_ERROR, "No valid decryption credentials found."); authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials(); List<Saml2X509Credential> decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials();
for (Saml2X509Credential key : decryptionCredentials) { for (Saml2X509Credential key : decryptionCredentials) {
if (!key.isDecryptionCredential()) {
continue;
}
Decrypter decrypter = getDecrypter(key); Decrypter decrypter = getDecrypter(key);
try { try {
return decrypter.decrypt(encrypted); return decrypter.decrypt(encrypted);
@ -623,11 +618,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) { private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
Set<Credential> credentials = new HashSet<>(); Set<Credential> credentials = new HashSet<>();
for (Saml2X509Credential key : token.getX509Credentials()) { for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) {
if (!key.isSignatureVerficationCredential()) continue;
BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING); cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(token.getIdpEntityId()); cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId());
credentials.add(cred); credentials.add(cred);
} }
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
@ -709,10 +703,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
} }
}, },
token -> { token -> {
String audience = token.getRelyingPartyRegistration().getEntityId();
String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis()); params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
params.put(COND_VALID_AUDIENCES, singleton(token.getIdpEntityId())); params.put(COND_VALID_AUDIENCES, singleton(audience));
params.put(SC_VALID_RECIPIENTS, singleton(token.getRecipientUri())); params.put(SC_VALID_RECIPIENTS, singleton(recipient));
params.putAll(this.validationContextParameters); params.putAll(this.validationContextParameters);
return new ValidationContext(params); return new ValidationContext(params);
}); });
@ -734,9 +730,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return encrypted -> { return encrypted -> {
Saml2AuthenticationException last = Saml2AuthenticationException last =
authException(DECRYPTION_ERROR, "No valid decryption credentials found."); authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials(); List<Saml2X509Credential> decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials();
for (Saml2X509Credential key : decryptionCredentials) { for (Saml2X509Credential key : decryptionCredentials) {
if (!key.isDecryptionCredential()) continue;
Decrypter decrypter = getDecrypter(key); Decrypter decrypter = getDecrypter(key);
try { try {
return (NameID) decrypter.decrypt(encrypted); return (NameID) decrypter.decrypt(encrypted);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,23 +16,51 @@
package org.springframework.security.saml2.provider.service.authentication; package org.springframework.security.saml2.provider.service.authentication;
import java.util.Collections;
import java.util.List;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.util.Assert;
import java.util.List; import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
/** /**
* Represents an incoming SAML 2.0 response containing an assertion that has not been validated. * Represents an incoming SAML 2.0 response containing an assertion that has not been validated.
* {@link Saml2AuthenticationToken#isAuthenticated()} will always return false. * {@link Saml2AuthenticationToken#isAuthenticated()} will always return false.
*
* @since 5.2 * @since 5.2
* @author Filip Hanik
* @author Josh Cummings
*/ */
public class Saml2AuthenticationToken extends AbstractAuthenticationToken { public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
private final RelyingPartyRegistration relyingPartyRegistration;
private final String saml2Response; private final String saml2Response;
private final String recipientUri;
private String idpEntityId; /**
private String localSpEntityId; * Creates a {@link Saml2AuthenticationToken} with the provided parameters
private List<Saml2X509Credential> credentials; *
* Note that the given {@link RelyingPartyRegistration} should have all its
* templates resolved at this point. See
* {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
* for an example of performing that resolution.
*
* @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to use
* @param saml2Response the SAML 2.0 response to authenticate
*
* @since 5.4
*/
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration,
String saml2Response) {
super(Collections.emptyList());
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null");
this.relyingPartyRegistration = relyingPartyRegistration;
this.saml2Response = saml2Response;
}
/** /**
* Creates an authentication token from an incoming SAML 2 Response object * Creates an authentication token from an incoming SAML 2 Response object
@ -41,18 +69,24 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @param idpEntityId the entity ID of the asserting entity * @param idpEntityId the entity ID of the asserting entity
* @param localSpEntityId the configured local SP, the relying party, entity ID * @param localSpEntityId the configured local SP, the relying party, entity ID
* @param credentials the credentials configured for signature verification and decryption * @param credentials the credentials configured for signature verification and decryption
* @deprecated Use {@link Saml2AuthenticationToken(RelyingPartyRegistration, String)} instead
*/ */
@Deprecated
public Saml2AuthenticationToken(String saml2Response, public Saml2AuthenticationToken(String saml2Response,
String recipientUri, String recipientUri,
String idpEntityId, String idpEntityId,
String localSpEntityId, String localSpEntityId,
List<Saml2X509Credential> credentials) { List<Saml2X509Credential> credentials) {
super(null); super(null);
this.relyingPartyRegistration = withRegistrationId(idpEntityId)
.entityId(localSpEntityId)
.assertionConsumerServiceLocation(recipientUri)
.credentials(c -> c.addAll(credentials))
.assertingPartyDetails(assertingParty -> assertingParty
.entityId(idpEntityId)
.singleSignOnServiceLocation(idpEntityId))
.build();
this.saml2Response = saml2Response; this.saml2Response = saml2Response;
this.recipientUri = recipientUri;
this.idpEntityId = idpEntityId;
this.localSpEntityId = localSpEntityId;
this.credentials = credentials;
} }
/** /**
@ -73,6 +107,16 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
return null; return null;
} }
/**
* Get the resolved {@link RelyingPartyRegistration} associated with the request
*
* @return the resolved {@link RelyingPartyRegistration}
* @since 5.4
*/
public RelyingPartyRegistration getRelyingPartyRegistration() {
return this.relyingPartyRegistration;
}
/** /**
* Returns inflated and decoded XML representation of the SAML 2 Response * Returns inflated and decoded XML representation of the SAML 2 Response
* @return inflated and decoded XML representation of the SAML 2 Response * @return inflated and decoded XML representation of the SAML 2 Response
@ -84,25 +128,31 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
/** /**
* Returns the URI that the SAML 2 Response object came in on * Returns the URI that the SAML 2 Response object came in on
* @return URI as a string * @return URI as a string
* @deprecated Use {@link #getRelyingPartyRegistration().getAssertionConsumerServiceLocation()} instead
*/ */
@Deprecated
public String getRecipientUri() { public String getRecipientUri() {
return this.recipientUri; return this.relyingPartyRegistration.getAssertionConsumerServiceLocation();
} }
/** /**
* Returns the configured entity ID of the receiving relying party, SP * Returns the configured entity ID of the receiving relying party, SP
* @return an entityID for the configured local relying party * @return an entityID for the configured local relying party
* @deprecated Use {@link #getRelyingPartyRegistration().getEntityId()} instead
*/ */
@Deprecated
public String getLocalSpEntityId() { public String getLocalSpEntityId() {
return this.localSpEntityId; return this.relyingPartyRegistration.getEntityId();
} }
/** /**
* Returns all the credentials associated with the relying party configuraiton * Returns all the credentials associated with the relying party configuraiton
* @return * @return
* @deprecated Get the credentials through {@link #getRelyingPartyRegistration()} instead
*/ */
@Deprecated
public List<Saml2X509Credential> getX509Credentials() { public List<Saml2X509Credential> getX509Credentials() {
return this.credentials; return this.relyingPartyRegistration.getCredentials();
} }
/** /**
@ -126,8 +176,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
/** /**
* Returns the configured IDP, asserting party, entity ID * Returns the configured IDP, asserting party, entity ID
* @return a string representing the entity ID * @return a string representing the entity ID
* @deprecated Use {@link #getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()} instead
*/ */
@Deprecated
public String getIdpEntityId() { public String getIdpEntityId() {
return this.idpEntityId; return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
} }
} }

View File

@ -35,6 +35,7 @@ import org.springframework.util.Assert;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.util.StringUtils.hasText; import static org.springframework.util.StringUtils.hasText;
/** /**
@ -98,14 +99,15 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
throw new Saml2AuthenticationException(saml2Error); throw new Saml2AuthenticationException(saml2Error);
} }
String applicationUri = Saml2ServletUtils.getApplicationUri(request); String applicationUri = Saml2ServletUtils.getApplicationUri(request);
String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp); String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
responseXml, rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
request.getRequestURL().toString(), RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
rp.getAssertingPartyDetails().getEntityId(), .entityId(relyingPartyEntityId)
localSpEntityId, .assertionConsumerServiceLocation(assertionConsumerServiceLocation)
rp.getCredentials() .build();
); Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
relyingPartyRegistration, responseXml);
return getAuthenticationManager().authenticate(authentication); return getAuthenticationManager().authenticate(authentication);
} }

View File

@ -111,14 +111,14 @@ public class OpenSamlAuthenticationProviderTests {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS)); this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME); Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
this.provider.authenticate(token(this.saml.serialize(assertion))); this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential()));
} }
@Test @Test
public void authenticateWhenXmlErrorThenThrowAuthenticationException() { public void authenticateWhenXmlErrorThenThrowAuthenticationException() {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
Saml2AuthenticationToken token = token("invalid xml"); Saml2AuthenticationToken token = token("invalid xml", relyingPartyVerifyingCredential());
this.provider.authenticate(token); this.provider.authenticate(token);
} }
@ -149,7 +149,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response(); Response response = response();
response.getAssertions().add(assertion()); response.getAssertions().add(assertion());
Saml2AuthenticationToken token = token(response); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
this.provider.authenticate(token); this.provider.authenticate(token);
} }
@ -316,7 +316,7 @@ public class OpenSamlAuthenticationProviderTests {
Response response = response(); Response response = response();
EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
response.getEncryptedAssertions().add(encryptedAssertion); response.getEncryptedAssertions().add(encryptedAssertion);
Saml2AuthenticationToken token = token(this.saml.serialize(response)); Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential());
this.provider.authenticate(token); this.provider.authenticate(token);
} }