Add Response to Authentication Conversion Support

Closes gh-8010
This commit is contained in:
Josh Cummings 2020-08-18 15:57:02 -06:00
parent 0c696dd58b
commit da7477cd41
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
3 changed files with 128 additions and 9 deletions

View File

@ -28,7 +28,6 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.xml.namespace.QName;
@ -185,8 +184,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private GrantedAuthoritiesMapper authoritiesMapper = (a -> a);
private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
private Function<Saml2AuthenticationToken, Converter<Response, AbstractAuthenticationToken>> authenticationConverter =
token -> response -> {
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter =
responseToken -> {
Response response = responseToken.response;
Saml2AuthenticationToken token = responseToken.token;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
@ -255,11 +256,42 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.assertionValidator = assertionValidator;
}
/**
* Set the {@link Converter} to use for converting a validated {@link Response} into
* an {@link AbstractAuthenticationToken}.
*
* You can delegate to the default behavior by calling {@link #createDefaultResponseAuthenticationConverter()}
* like so:
*
* <pre>
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
* Converter&lt;ResponseToken, Saml2Authentication&gt; authenticationConverter =
* createDefaultResponseAuthenticationConverter();
* provider.setResponseAuthenticationConverter(responseToken -> {
* Saml2Authentication authentication = authenticationConverter.convert(responseToken);
* User user = myUserRepository.findByUsername(authentication.getName());
* return new MyAuthentication(authentication, user);
* });
* </pre>
*
* This method takes precedence over {@link #setAuthoritiesExtractor(Converter)} and
* {@link #setAuthoritiesMapper(GrantedAuthoritiesMapper)}.
*
* @param responseAuthenticationConverter the {@link Converter} to use
* @since 5.4
*/
public void setResponseAuthenticationConverter(
Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter) {
Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null");
this.responseAuthenticationConverter = responseAuthenticationConverter;
}
/**
* Sets the {@link Converter} used for extracting assertion attributes that
* can be mapped to authorities.
* @param authoritiesExtractor the {@code Converter} used for mapping the
* assertion attributes to authorities
* @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
*/
public void setAuthoritiesExtractor(Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor) {
Assert.notNull(authoritiesExtractor, "authoritiesExtractor cannot be null");
@ -271,6 +303,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
* to a new set of authorities which will be associated to the {@link Saml2Authentication}.
* Note: This implementation is only retrieving
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities
* @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
*/
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
notNull(authoritiesMapper, "authoritiesMapper cannot be null");
@ -286,6 +319,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this.responseTimeValidationSkew = responseTimeValidationSkew;
}
/**
* Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication}
* token into a {@link Saml2Authentication}
*
* @return the default response authentication converter strategy
* @since 5.4
*/
public static Converter<ResponseToken, Saml2Authentication>
createDefaultResponseAuthenticationConverter() {
return responseToken -> {
Saml2AuthenticationToken token = responseToken.token;
Response response = responseToken.response;
Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(
new DefaultSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")));
};
}
/**
* @param authentication the authentication request object, must be of type
* {@link Saml2AuthenticationToken}
@ -300,7 +354,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
String serializedResponse = token.getSaml2Response();
Response response = parse(serializedResponse);
process(token, response);
return this.authenticationConverter.apply(token).convert(response);
return this.responseAuthenticationConverter.convert(new ResponseToken(response, token));
} catch (Saml2AuthenticationException e) {
throw e;
} catch (Exception e) {
@ -496,7 +550,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
}
private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
@ -515,7 +569,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return attributeMap;
}
private Object getXmlObjectValue(XMLObject xmlObject) {
private static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
}
@ -706,6 +760,29 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return new Saml2AuthenticationException(validationError(code, description), cause);
}
/**
* A tuple containing an OpenSAML {@link Response} and its associated authentication token.
*
* @since 5.4
*/
public static class ResponseToken {
private final Saml2AuthenticationToken token;
private final Response response;
ResponseToken(Response response, Saml2AuthenticationToken token) {
this.token = token;
this.response = response;
}
public Response getResponse() {
return this.response;
}
public Saml2AuthenticationToken getToken() {
return this.token;
}
}
/**
* A tuple containing an OpenSAML {@link Assertion} and its associated authentication token.
*

View File

@ -77,17 +77,20 @@ import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParamete
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator;
import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultResponseAuthenticationConverter;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signedResponseWithOneAssertion;
import static org.springframework.util.StringUtils.hasText;
/**
@ -103,6 +106,10 @@ public class OpenSamlAuthenticationProviderTests {
private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal
("name", Collections.emptyMap());
private Saml2Authentication authentication = new Saml2Authentication
(this.principal, "response", Collections.emptyList());
@Rule
public ExpectedException exception = ExpectedException.none();
@ -380,7 +387,7 @@ public class OpenSamlAuthenticationProviderTests {
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)))
.thenReturn(Saml2ResponseValidatorResult.success());
.thenReturn(success());
provider.authenticate(token);
verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class));
}
@ -388,7 +395,7 @@ public class OpenSamlAuthenticationProviderTests {
@Test
public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() {
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success());
provider.setAssertionValidator(assertionToken -> success());
Response response = response();
Assertion assertion = assertion();
signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature
@ -424,6 +431,35 @@ public class OpenSamlAuthenticationProviderTests {
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
Response response = signedResponseWithOneAssertion();
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
OpenSamlAuthenticationProvider.ResponseToken responseToken =
new OpenSamlAuthenticationProvider.ResponseToken(response, token);
Saml2Authentication authentication = createDefaultResponseAuthenticationConverter()
.convert(responseToken);
assertThat(authentication.getName()).isEqualTo("test@saml.user");
}
@Test
public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter =
mock(Converter.class);
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
provider.setResponseAuthenticationConverter(authenticationConverter);
Response response = signedResponseWithOneAssertion();
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
provider.authenticate(token);
verify(authenticationConverter).convert(any());
}
@Test
public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() {
assertThatCode(() -> this.provider.setResponseAuthenticationConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
private <T extends XMLObject> T build(QName qName) {
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
}

View File

@ -79,6 +79,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
final class TestOpenSamlObjects {
static {
@ -107,6 +108,12 @@ final class TestOpenSamlObjects {
return response;
}
static Response signedResponseWithOneAssertion() {
Response response = response();
response.getAssertions().add(assertion());
return signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
}
static Assertion assertion() {
return assertion(USERNAME, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, DESTINATION);
}
@ -135,7 +142,6 @@ final class TestOpenSamlObjects {
return assertion;
}
static Issuer issuer(String entityId) {
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
issuer.setValue(entityId);