diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java index b7f4d9c799..ecee9777f1 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,13 @@ public interface Saml2ErrorCodes { */ String UNKNOWN_RESPONSE_CLASS = "unknown_response_class"; + /** + * The serialized AuthNRequest could not be deserialized correctly. + * + * @since 5.7 + */ + String MALFORMED_REQUEST_DATA = "malformed_request_data"; + /** * The response data is malformed or incomplete. An invalid XML object was received, * and XML unmarshalling failed. @@ -116,4 +123,11 @@ public interface Saml2ErrorCodes { */ String RELYING_PARTY_REGISTRATION_NOT_FOUND = "relying_party_registration_not_found"; + /** + * The InResponseTo content of the response does not match the ID of the AuthNRequest. + * + * @since 5.7 + */ + String INVALID_IN_RESPONSE_TO = "invalid_in_response_to"; + } diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 51d9dc1abb..835fb61677 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -37,6 +37,7 @@ import org.apache.commons.logging.LogFactory; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.core.xml.schema.XSBoolean; import org.opensaml.core.xml.schema.XSBooleanValue; @@ -57,13 +58,17 @@ import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionVali import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnStatement; import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.StatusCode; +import org.opensaml.saml.saml2.core.Subject; import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; import org.opensaml.saml.saml2.encryption.Decrypter; import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; @@ -85,6 +90,7 @@ import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -349,6 +355,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv this.responseAuthenticationConverter = responseAuthenticationConverter; } + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, + String inResponseTo) { + if (!StringUtils.hasText(inResponseTo)) { + return Saml2ResponseValidatorResult.success(); + } + AuthnRequest request; + try { + request = parseRequest(storedRequest); + } + catch (Exception ex) { + String message = "The stored AuthNRequest could not be properly deserialized [" + ex.getMessage() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message)); + } + if (request == null) { + String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" + + " but no saved AuthNRequest request was found"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + else if (!request.getID().equals(inResponseTo)) { + String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " + + "AuthNRequest [" + request.getID() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + else { + return Saml2ResponseValidatorResult.success(); + } + } + /** * Construct a default strategy for validating the SAML 2.0 Response * @return the default response validator strategy @@ -365,6 +402,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv response.getID()); result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); } + + String inResponseTo = response.getInResponseTo(); + result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo)); + String issuer = response.getIssuer().getValue(); String destination = response.getDestination(); String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); @@ -447,7 +488,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv try { Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication; String serializedResponse = token.getSaml2Response(); - Response response = parse(serializedResponse); + Response response = parseResponse(serializedResponse); process(token, response); AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter .convert(new ResponseToken(response, token)); @@ -469,7 +510,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); } - private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException { + private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException { try { Document document = this.parserPool .parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))); @@ -481,6 +522,28 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv } } + private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) throws Exception { + if (request == null) { + return null; + } + String samlRequest = request.getSamlRequest(); + if (!StringUtils.hasText(samlRequest)) { + return null; + } + if (request.getBinding() == Saml2MessageBinding.REDIRECT) { + samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); + } + else { + samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); + } + Document document = XMLObjectProviderRegistrySupport.getParserPool() + .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport + .getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + return (AuthnRequest) unmarshaller.unmarshall(element); + } + private void process(Saml2AuthenticationToken token, Response response) { String issuer = response.getIssuer().getValue(); this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); @@ -685,6 +748,30 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv }; } + private static boolean assertionContainsInResponseTo(Assertion assertion) { + Subject subject = (assertion != null) ? assertion.getSubject() : null; + List confirmations = (subject != null) ? subject.getSubjectConfirmations() + : new ArrayList<>(); + return confirmations.stream().filter((confirmation) -> { + SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); + return confirmationData != null && StringUtils.hasText(confirmationData.getInResponseTo()); + }).findFirst().orElse(null) != null; + } + + private static void addRequestIdToValidationContext(AbstractSaml2AuthenticationRequest storedRequest, + Map context) { + String requestId = null; + try { + AuthnRequest request = parseRequest(storedRequest); + requestId = (request != null) ? request.getID() : null; + } + catch (Exception ex) { + } + if (StringUtils.hasText(requestId)) { + context.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); + } + } + private static ValidationContext createValidationContext(AssertionToken assertionToken, Consumer> paramsConsumer) { RelyingPartyRegistration relyingPartyRegistration = assertionToken.token.getRelyingPartyRegistration(); @@ -692,6 +779,10 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); Map params = new HashMap<>(); + Assertion assertion = assertionToken.getAssertion(); + if (assertionContainsInResponseTo(assertion)) { + addRequestIdToValidationContext(assertionToken.token.getAuthenticationRequest(), params); + } params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId)); @@ -733,13 +824,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv // applications should validate their own addresses - gh-7514 return ValidationResult.VALID; } - - @Override - protected ValidationResult validateInResponseTo(SubjectConfirmation confirmation, Assertion assertion, - ValidationContext context, boolean required) { - // applications should validate their own in response to - return ValidationResult.VALID; - } }); } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index fdff0619bf..2cf7c62ba7 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -46,6 +47,7 @@ import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAttribute; @@ -74,6 +76,7 @@ import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken; import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StringUtils; @@ -217,6 +220,111 @@ public class OpenSaml4AuthenticationProviderTests { this.provider.authenticate(token); } + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsMatchRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, true); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchWithRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("BAD"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion())); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, true); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() { + Response response = response(); + response.setInResponseTo("BAD"); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + @Test public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { Response response = response(); @@ -658,13 +766,27 @@ public class OpenSaml4AuthenticationProviderTests { return response; } - private Assertion assertion() { + private AuthnRequest request() { + AuthnRequest request = TestOpenSamlObjects.authnRequest(); + return request; + } + + private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) { + String xml = serialize(request); + return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)) + : Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + } + + private Assertion assertion(String inResponseTo) { Assertion assertion = TestOpenSamlObjects.assertion(); assertion.setIssueInstant(Instant.now()); for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { SubjectConfirmationData data = confirmation.getSubjectConfirmationData(); data.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); data.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000))); + if (StringUtils.hasText(inResponseTo)) { + data.setInResponseTo(inResponseTo); + } } Conditions conditions = assertion.getConditions(); conditions.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); @@ -672,6 +794,10 @@ public class OpenSaml4AuthenticationProviderTests { return assertion; } + private Assertion assertion() { + return assertion(null); + } + private T signed(T toSign) { TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); @@ -701,6 +827,27 @@ public class OpenSaml4AuthenticationProviderTests { return new Saml2AuthenticationToken(registration.build(), serialize(response)); } + private Saml2AuthenticationToken token(Response response, RelyingPartyRegistration.Builder registration, + AbstractSaml2AuthenticationRequest authenticationRequest) { + return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest); + } + + private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId, + Saml2MessageBinding binding, boolean corruptRequestString) { + AuthnRequest request = request(); + if (requestId != null) { + request.setID(requestId); + } + String serializedRequest = serializedRequest(request, binding); + if (corruptRequestString) { + serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2); + } + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest); + given(mockAuthenticationRequest.getBinding()).willReturn(binding); + return mockAuthenticationRequest; + } + private RelyingPartyRegistration.Builder registration() { return TestRelyingPartyRegistrations.noCredentials().entityId(RELYING_PARTY_ENTITY_ID) .assertionConsumerServiceLocation(DESTINATION)