diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/authentication.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/authentication.adoc index b6f5fbefbb..97bcbccfbc 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/login/authentication.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/login/authentication.adoc @@ -13,35 +13,11 @@ You can configure this in a number of ways including: To configure these, you'll use the `saml2Login#authenticationManager` method in the DSL. -[[relyingpartyregistrationresolver-apply]] -== Changing `RelyingPartyRegistration` Lookup +[[saml2-response-processing-endpoint]] +== Changing the SAML Response Processing Endpoint -`RelyingPartyRegistration` lookup is customized xref:servlet/saml2/login/overview.adoc#servlet-saml2login-rpr-relyingpartyregistrationresolver[in a `RelyingPartyRegistrationResolver`]. - -To apply a `RelyingPartyRegistrationResolver` when processing `` payloads, you should first publish a `Saml2AuthenticationTokenConverter` bean like so: - -==== -.Java -[source,java,role="primary"] ----- -@Bean -Saml2AuthenticationTokenConverter authenticationConverter(InMemoryRelyingPartyRegistrationRepository registrations) { - return new Saml2AuthenticationTokenConverter(new MyRelyingPartyRegistrationResolver(registrations)); -} ----- - -.Kotlin -[source,kotlin,role="secondary"] ----- -@Bean -fun authenticationConverter(val registrations: InMemoryRelyingPartyRegistrationRepository): Saml2AuthenticationTokenConverter { - return Saml2AuthenticationTokenConverter(MyRelyingPartyRegistrationResolver(registrations)); -} ----- -==== - -Recall that the Assertion Consumer Service URL is `+/saml2/login/sso/{registrationId}+` by default. -If you are no longer wanting the `registrationId` in the URL, change it in the filter chain and in your relying party metadata: +The default endpoint is `+/login/saml2/sso/{registrationId}+`. +You can change this in the DSL and in the associated metadata like so: ==== .Java @@ -82,13 +58,55 @@ and: .Java [source,java,role="primary"] ---- -relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml2/login/sso") +relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml/SSO") ---- .Kotlin [source,kotlin,role="secondary"] ---- -relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml2/login/sso") +relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml/SSO") +---- +==== + +[[relyingpartyregistrationresolver-apply]] +== Changing `RelyingPartyRegistration` lookup + +By default, this converter will match against any associated `` or any `registrationId` it finds in the URL. +Or, if it cannot find one in either of those cases, then it attempts to look it up by the `` element. + +There are a number of circumstances where you might need something more sophisticated, like if you are supporting `ARTIFACT` binding. +In those cases, you can customize lookup through a custom `AuthenticationConverter`, which you can customize like so: + +==== +.Java +[source,java,role="primary"] +---- +@Bean +SecurityFilterChain securityFilters(HttpSecurity http, AuthenticationConverter authenticationConverter) throws Exception { + http + // ... + .saml2Login((saml2) -> saml2.authenticationConverter(authenticationConverter)) + // ... + + return http.build(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +fun securityFilters(val http: HttpSecurity, val converter: AuthenticationConverter): SecurityFilterChain { + http { + // ... + .saml2Login { + authenticationConverter = converter + } + // ... + } + + return http.build() +} ---- ==== diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java index a513bc1bb8..3546a1e220 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.web; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -25,8 +26,18 @@ import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; import jakarta.servlet.http.HttpServletRequest; +import net.shibboleth.utilities.java.support.xml.ParserPool; +import org.opensaml.core.config.ConfigurationService; +import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; +import org.w3c.dom.Document; +import org.w3c.dom.Element; import org.springframework.http.HttpMethod; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ParameterNames; @@ -34,7 +45,12 @@ import org.springframework.security.saml2.provider.service.authentication.Abstra import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; /** @@ -43,9 +59,13 @@ import org.springframework.util.Assert; * {@link org.springframework.security.authentication.AuthenticationManager}. * * @author Josh Cummings - * @since 5.4 + * @since 6.1 */ -public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter { +public final class OpenSamlAuthenticationTokenConverter implements AuthenticationConverter { + + static { + OpenSamlInitializationService.initialize(); + } // MimeDecoder allows extra line-breaks as well as other non-alphabet values. // This matches the behaviour of the commons-codec decoder. @@ -53,39 +73,120 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); - private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; + private final RelyingPartyRegistrationRepository registrations; + + private RequestMatcher requestMatcher = new OrRequestMatcher( + new AntPathRequestMatcher("/login/saml2/sso/{registrationId}"), + new AntPathRequestMatcher("/login/saml2/sso")); + + private final ParserPool parserPool; + + private final ResponseUnmarshaller unmarshaller; private Function loader; /** - * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for - * resolving {@link RelyingPartyRegistration}s - * @param relyingPartyRegistrationResolver the strategy for resolving + * Constructs a {@link OpenSamlAuthenticationTokenConverter} given a repository for + * {@link RelyingPartyRegistration}s + * @param registrations the repository for {@link RelyingPartyRegistration}s * {@link RelyingPartyRegistration}s */ - public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); - this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + public OpenSamlAuthenticationTokenConverter(RelyingPartyRegistrationRepository registrations) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + this.parserPool = registry.getParserPool(); + this.unmarshaller = (ResponseUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() + .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME); + this.registrations = registrations; this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; } + /** + * Resolve an authentication request from the given {@link HttpServletRequest}. + * + *

+ * First uses the configured {@link RequestMatcher} to deduce whether an + * authentication request is being made and optionally for which + * {@code registrationId}. + * + *

+ * If there is an associated {@code }, then the + * {@code registrationId} is looked up and used. + * + *

+ * If a {@code registrationId} is found in the request, then it is looked up and used. + * In that case, if none is found a {@link Saml2AuthenticationException} is thrown. + * + *

+ * Finally, if no {@code registrationId} is found in the request, then the code + * attempts to resolve the {@link RelyingPartyRegistration} from the SAML Response's + * Issuer. + * @param request the HTTP request + * @return the {@link Saml2AuthenticationToken} authentication request + * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a + * non-existent {@code registrationId} + */ @Override public Saml2AuthenticationToken convert(HttpServletRequest request) { + String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); + if (serialized == null) { + return null; + } + RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); + if (!result.isMatch()) { + return null; + } + Saml2AuthenticationToken token = tokenByAuthenticationRequest(request); + if (token == null) { + token = tokenByRegistrationId(request, result); + } + if (token == null) { + token = tokenByEntityId(request); + } + return token; + } + + private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) { AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); - String relyingPartyRegistrationId = (authenticationRequest != null) - ? authenticationRequest.getRelyingPartyRegistrationId() : null; - RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, - relyingPartyRegistrationId); - if (relyingPartyRegistration == null) { + if (authenticationRequest == null) { return null; } - String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); - if (saml2Response == null) { + String registrationId = authenticationRequest.getRelyingPartyRegistrationId(); + RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId); + return tokenByRegistration(request, registration, authenticationRequest); + } + + private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request, + RequestMatcher.MatchResult result) { + String registrationId = result.getVariables().get("registrationId"); + if (registrationId == null) { return null; } - byte[] b = samlDecode(saml2Response); - saml2Response = inflateIfRequired(request, b); - return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest); + RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId); + return tokenByRegistration(request, registration, null); + } + + private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) { + String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); + String decoded = new String(samlDecode(serialized), StandardCharsets.UTF_8); + Response response = parse(decoded); + String issuer = response.getIssuer().getValue(); + RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer); + return tokenByRegistration(request, registration, null); + } + + private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request, + RelyingPartyRegistration registration, AbstractSaml2AuthenticationRequest authenticationRequest) { + if (registration == null) { + return null; + } + String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); + String decoded = inflateIfRequired(request, samlDecode(serialized)); + UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration); + registration = registration.mutate().entityId(resolver.resolve(registration.getEntityId())) + .assertionConsumerServiceLocation(resolver.resolve(registration.getAssertionConsumerServiceLocation())) + .build(); + return new Saml2AuthenticationToken(registration, decoded, authenticationRequest); } /** @@ -100,6 +201,15 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo this.loader = authenticationRequestRepository::loadAuthenticationRequest; } + /** + * Use the given {@link RequestMatcher} to match the request. + * @param requestMatcher the {@link RequestMatcher} to use + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.requestMatcher = requestMatcher; + } + private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { return this.loader.apply(request); } @@ -136,6 +246,18 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo } } + private Response parse(String request) throws Saml2Exception { + try { + Document document = this.parserPool + .parse(new ByteArrayInputStream(request.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + return (Response) this.unmarshaller.unmarshall(element); + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize LogoutRequest", ex); + } + } + static class Base64Checker { private static final int[] values = genValueMapping(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 5fd9b3076d..5699fd832a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -105,7 +105,7 @@ public final class TestOpenSamlObjects { public static String RELYING_PARTY_ENTITY_ID = "https://localhost/saml2/service-provider-metadata/idp-alias"; - private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; + public static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; private static SecretKey SECRET_KEY = new SecretKeySpec( Base64.getDecoder().decode("shOnwNMoCv88HKMEa91+FlYoD5RNvzMTAL5LGxZKIFk="), "AES"); @@ -113,7 +113,7 @@ public final class TestOpenSamlObjects { private TestOpenSamlObjects() { } - static Response response() { + public static Response response() { return response(DESTINATION, ASSERTING_PARTY_ENTITY_ID); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java new file mode 100644 index 0000000000..181e50706a --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java @@ -0,0 +1,258 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +import jakarta.servlet.http.HttpServletRequest; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.common.SignableSAMLObject; +import org.opensaml.saml.saml2.core.Response; +import org.w3c.dom.Element; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2Utils; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.util.StreamUtils; +import org.springframework.web.util.UriUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link OpenSamlAuthenticationTokenConverter} + */ +@ExtendWith(MockitoExtension.class) +public final class OpenSamlAuthenticationTokenConverterTests { + + @Mock + RelyingPartyRegistrationRepository registrations; + + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + + @Test + public void convertWhenSamlResponseThenToken() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "invalid"); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withCauseInstanceOf(IllegalArgumentException.class) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode()) + .isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription()) + .isEqualTo("Failed to decode SAMLResponse")); + } + + @Test + public void convertWhenNoSamlResponseThenNull() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenNoMatchingRequestThenNull() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "ignored"); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenNoRelyingPartyRegistrationThenNull() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + String response = Saml2Utils.samlEncode(serialize(signed(response())).getBytes(StandardCharsets.UTF_8)); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, response); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenGetRequestThenInflates() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId()); + byte[] deflated = Saml2Utils.samlDeflate("response"); + String encoded = Saml2Utils.samlEncode(deflated); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId()); + byte[] invalidDeflated = "invalid".getBytes(); + String encoded = Saml2Utils.samlEncode(invalidDeflated); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withCauseInstanceOf(IOException.class) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode()) + .isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies( + (ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string")); + } + + @Test + public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, getSsoCircleEncodedXml()); + Saml2AuthenticationToken token = converter.convert(request); + validateSsoCircleXml(token.getSaml2Response()); + } + + @Test + public void convertWhenSavedAuthenticationRequestThenToken() { + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(authenticationRequest.getRelyingPartyRegistrationId()).willReturn(this.registration.getRegistrationId()); + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + converter.setAuthenticationRequestRepository(authenticationRequestRepository); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class))) + .willReturn(authenticationRequest); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest); + } + + @Test + public void convertWhenMatchingNoRegistrationIdThenLooksUpByAssertingEntityId() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + String response = serialize(signed(response())); + String encoded = Saml2Utils.samlEncode(response.getBytes(StandardCharsets.UTF_8)); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso"); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo(response); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void constructorWhenResolverIsNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null)); + } + + @Test + public void setAuthenticationRequestRepositoryWhenNullThenIllegalArgument() { + OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> converter.setAuthenticationRequestRepository(null)); + } + + private void validateSsoCircleXml(String xml) { + assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") + .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") + .contains("https://idp.ssocircle.com"); + } + + private String getSsoCircleEncodedXml() throws IOException { + ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded"); + String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8); + return UriUtils.decode(response, StandardCharsets.UTF_8); + } + + private MockHttpServletRequest post(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); + request.setServletPath(uri); + return request; + } + + private MockHttpServletRequest get(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setServletPath(uri); + return request; + } + + private T signed(T toSign) { + TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(), + TestOpenSamlObjects.RELYING_PARTY_ENTITY_ID); + return toSign; + } + + private Response response() { + Response response = TestOpenSamlObjects.response(); + response.setIssueInstant(Instant.now()); + return response; + } + + private String serialize(XMLObject object) { + try { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + Element element = marshaller.marshall(object); + return SerializeSupport.nodeToString(element); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + +}