From 5061ae9e7900b590a568a377fd58b9663323f3d6 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 4 Aug 2020 17:28:42 -0600 Subject: [PATCH] Add Saml2AuthenticationTokenConverter Closes gh-8768 --- .../saml2/Saml2LoginConfigurer.java | 28 ++++- .../saml2/Saml2LoginConfigurerTests.java | 42 +++++++ .../Saml2WebSsoAuthenticationFilter.java | 74 +++++------- .../Saml2AuthenticationTokenConverter.java | 105 ++++++++++++++++++ .../security/saml2/core/Saml2Utils.java | 70 ++++++++++++ .../TestRelyingPartyRegistrations.java | 10 +- .../Saml2WebSsoAuthenticationFilterTests.java | 2 +- ...aml2AuthenticationTokenConverterTests.java | 102 +++++++++++++++++ .../samples/config/SecurityConfigTests.java | 2 +- 9 files changed, 386 insertions(+), 49 deletions(-) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 321a492f0a..ca7b6d2c08 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -28,6 +28,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractAuthenticationFilterConfigurer; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; +import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; @@ -38,6 +39,8 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -106,10 +109,25 @@ public final class Saml2LoginConfigurer> extend private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private AuthenticationConverter authenticationConverter; private AuthenticationManager authenticationManager; private Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter; + /** + * Use this {@link AuthenticationConverter} when converting incoming requests to an {@link Authentication}. + * By default the {@link Saml2AuthenticationTokenConverter} is used. + * + * @param authenticationConverter the {@link AuthenticationConverter} to use + * @return the {@link Saml2LoginConfigurer} for further configuration + * @since 5.4 + */ + public Saml2LoginConfigurer authenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + return this; + } + /** * Allows a configuration of a {@link AuthenticationManager} to be used during SAML 2 authentication. * If none is specified, the system will create one inject it into the {@link Saml2WebSsoAuthenticationFilter} @@ -187,7 +205,7 @@ public final class Saml2LoginConfigurer> extend } saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter( - this.relyingPartyRegistrationRepository, + getAuthenticationConverter(http), this.loginProcessingUrl ); setAuthenticationFilter(saml2WebSsoAuthenticationFilter); @@ -241,6 +259,14 @@ public final class Saml2LoginConfigurer> extend } } + private AuthenticationConverter getAuthenticationConverter(B http) { + if (this.authenticationConverter == null) { + return new Saml2AuthenticationTokenConverter( + new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository)); + } + return this.authenticationConverter; + } + private void registerDefaultAuthenticationProvider(B http) { OpenSamlAuthenticationProvider provider = postProcess(new OpenSamlAuthenticationProvider()); http.authenticationProvider(provider); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 21845bcec3..d4be974be4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -65,10 +65,12 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; @@ -86,9 +88,13 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; +import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** @@ -101,6 +107,8 @@ public class Saml2LoginConfigurerTests { private static final GrantedAuthoritiesMapper AUTHORITIES_MAPPER = authorities -> Arrays.asList(new SimpleGrantedAuthority("TEST CONVERTED")); private static final Duration RESPONSE_TIME_VALIDATION_SKEW = Duration.ZERO; + private static final String SIGNED_RESPONSE = + "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz48c2FtbDJwOlJlc3BvbnNlIHhtbG5zOnNhbWwycD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIiBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9ycC5leGFtcGxlLm9yZy9hY3MiIElEPSJfYzE3MzM2YTAtNTM1My00MTQ5LWI3MmMtMDNkOWY5YWYzMDdlIiBJc3N1ZUluc3RhbnQ9IjIwMjAtMDgtMDRUMjI6MDQ6NDUuMDE2WiIgVmVyc2lvbj0iMi4wIj48c2FtbDI6SXNzdWVyIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48ZHM6U2lnbmF0dXJlIHhtbG5zOmRzPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwLzA5L3htbGRzaWcjIj4KPGRzOlNpZ25lZEluZm8+CjxkczpDYW5vbmljYWxpemF0aW9uTWV0aG9kIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMS8xMC94bWwtZXhjLWMxNG4jIi8+CjxkczpTaWduYXR1cmVNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGRzaWctbW9yZSNyc2Etc2hhMjU2Ii8+CjxkczpSZWZlcmVuY2UgVVJJPSIjX2MxNzMzNmEwLTUzNTMtNDE0OS1iNzJjLTAzZDlmOWFmMzA3ZSI+CjxkczpUcmFuc2Zvcm1zPgo8ZHM6VHJhbnNmb3JtIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMC8wOS94bWxkc2lnI2VudmVsb3BlZC1zaWduYXR1cmUiLz4KPGRzOlRyYW5zZm9ybSBBbGdvcml0aG09Imh0dHA6Ly93d3cudzMub3JnLzIwMDEvMTAveG1sLWV4Yy1jMTRuIyIvPgo8L2RzOlRyYW5zZm9ybXM+CjxkczpEaWdlc3RNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGVuYyNzaGEyNTYiLz4KPGRzOkRpZ2VzdFZhbHVlPjYzTmlyenFzaDVVa0h1a3NuRWUrM0hWWU5aYWFsQW1OQXFMc1lGMlRuRDA9PC9kczpEaWdlc3RWYWx1ZT4KPC9kczpSZWZlcmVuY2U+CjwvZHM6U2lnbmVkSW5mbz4KPGRzOlNpZ25hdHVyZVZhbHVlPgpLMVlvWWJVUjBTclY4RTdVMkhxTTIvZUNTOTNoV25mOExnNnozeGZWMUlyalgzSXhWYkNvMVlYcnRBSGRwRVdvYTJKKzVOMmFNbFBHJiMxMzsKN2VpbDBZRC9xdUVRamRYbTNwQTBjZmEvY25pa2RuKzVhbnM0ZWQwanU1amo2dkpvZ2w2Smt4Q25LWUpwTU9HNzhtampmb0phengrWCYjMTM7CkM2NktQVStBYUdxeGVwUEQ1ZlhRdTFKSy9Jb3lBaitaa3k4Z2Jwc3VyZHFCSEJLRWxjdnVOWS92UGY0OGtBeFZBKzdtRGhNNUMvL1AmIzEzOwp0L084Y3NZYXB2UjZjdjZrdk45QXZ1N3FRdm9qVk1McHVxZWNJZDJwTUVYb0NSSnE2Nkd4MStNTUVPeHVpMWZZQlRoMEhhYjRmK3JyJiMxMzsKOEY2V1NFRC8xZllVeHliRkJqZ1Q4d2lEWHFBRU8wSVY4ZWRQeEE9PQo8L2RzOlNpZ25hdHVyZVZhbHVlPgo8L2RzOlNpZ25hdHVyZT48c2FtbDI6QXNzZXJ0aW9uIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIiBJRD0iQWUzZjQ5OGI4LTliMTctNDA3OC05ZDM1LTg2YTA4NDA4NDk5NSIgSXNzdWVJbnN0YW50PSIyMDIwLTA4LTA0VDIyOjA0OjQ1LjA3N1oiIFZlcnNpb249IjIuMCI+PHNhbWwyOklzc3Vlcj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48c2FtbDI6U3ViamVjdD48c2FtbDI6TmFtZUlEPnRlc3RAc2FtbC51c2VyPC9zYW1sMjpOYW1lSUQ+PHNhbWwyOlN1YmplY3RDb25maXJtYXRpb24gTWV0aG9kPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6Y206YmVhcmVyIj48c2FtbDI6U3ViamVjdENvbmZpcm1hdGlvbkRhdGEgTm90QmVmb3JlPSIyMDIwLTA4LTA0VDIxOjU5OjQ1LjA5MFoiIE5vdE9uT3JBZnRlcj0iMjA0MC0wNy0zMFQyMjowNTowNi4wODhaIiBSZWNpcGllbnQ9Imh0dHBzOi8vcnAuZXhhbXBsZS5vcmcvYWNzIi8+PC9zYW1sMjpTdWJqZWN0Q29uZmlybWF0aW9uPjwvc2FtbDI6U3ViamVjdD48c2FtbDI6Q29uZGl0aW9ucyBOb3RCZWZvcmU9IjIwMjAtMDgtMDRUMjE6NTk6NDUuMDgwWiIgTm90T25PckFmdGVyPSIyMDQwLTA3LTMwVDIyOjA1OjA2LjA4N1oiLz48L3NhbWwyOkFzc2VydGlvbj48L3NhbWwycDpSZXNwb25zZT4="; @Autowired private ConfigurableApplicationContext context; @@ -181,6 +189,23 @@ public class Saml2LoginConfigurerTests { assertThat(inflated).contains("ForceAuthn=\"true\""); } + @Test + public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception { + this.spring.register(CustomAuthenticationConverter.class).autowire(); + RelyingPartyRegistration relyingPartyRegistration = noCredentials() + .assertingPartyDetails(party -> party + .verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())) + ) + .build(); + String response = new String(samlDecode(SIGNED_RESPONSE)); + when(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class))) + .thenReturn(new Saml2AuthenticationToken(relyingPartyRegistration, response)); + this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()) + .param("SAMLResponse", SIGNED_RESPONSE)) + .andExpect(redirectedUrl("/")); + verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class)); + } + private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); @@ -311,6 +336,23 @@ public class Saml2LoginConfigurerTests { } } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter { + static final AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests(authz -> authz + .anyRequest().authenticated() + ) + .saml2Login(saml2 -> saml2 + .authenticationConverter(authenticationConverter) + ); + } + } + private static AuthenticationManager getAuthenticationManagerMock(String role) { return new AuthenticationManager() { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index ddcc854d0e..c073ff0092 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -19,23 +19,19 @@ package org.springframework.security.saml2.provider.service.servlet.filter; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.http.HttpMethod; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; -import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.core.Saml2Error; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy; -import org.springframework.security.web.util.matcher.AntPathRequestMatcher; -import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.util.StringUtils.hasText; /** @@ -44,8 +40,7 @@ import static org.springframework.util.StringUtils.hasText; public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter { public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}"; - private final RequestMatcher matcher; - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final AuthenticationConverter authenticationConverter; /** * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured @@ -64,16 +59,30 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce public Saml2WebSsoAuthenticationFilter( RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, String filterProcessesUrl) { - super(filterProcessesUrl); - Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); - Assert.hasText(filterProcessesUrl, "filterProcessesUrl must contain a URL pattern"); + this(new Saml2AuthenticationTokenConverter + (new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), + filterProcessesUrl); + } + + /** + * Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters + * + * @param authenticationConverter the strategy for converting an {@link HttpServletRequest} + * into an {@link Authentication} + * @param filterProcessingUrl the processing URL, must contain a {registrationId} variable + * @since 5.4 + */ + public Saml2WebSsoAuthenticationFilter( + AuthenticationConverter authenticationConverter, + String filterProcessingUrl) { + super(filterProcessingUrl); + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern"); Assert.isTrue( - filterProcessesUrl.contains("{registrationId}"), + filterProcessingUrl.contains("{registrationId}"), "filterProcessesUrl must contain a {registrationId} match variable" ); - this.matcher = new AntPathRequestMatcher(filterProcessesUrl); - setRequiresAuthenticationRequestMatcher(this.matcher); - this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + this.authenticationConverter = authenticationConverter; setAllowSessionCreation(true); setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy()); } @@ -86,37 +95,12 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce @Override public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException { - String saml2Response = request.getParameter("SAMLResponse"); - byte[] b = Saml2Utils.samlDecode(saml2Response); - - String responseXml = inflateIfRequired(request, b); - String registrationId = this.matcher.matcher(request).getVariables().get("registrationId"); - RelyingPartyRegistration rp = - this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); - if (rp == null) { + Authentication authentication = this.authenticationConverter.convert(request); + if (authentication == null) { Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND, - "Relying Party Registration not found with ID: " + registrationId); + "No relying party registration found"); throw new Saml2AuthenticationException(saml2Error); } - String applicationUri = Saml2ServletUtils.getApplicationUri(request); - String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp); - String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate( - rp.getAssertionConsumerServiceLocation(), applicationUri, rp); - RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp) - .entityId(relyingPartyEntityId) - .assertionConsumerServiceLocation(assertionConsumerServiceLocation) - .build(); - Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( - relyingPartyRegistration, responseXml); return getAuthenticationManager().authenticate(authentication); } - - private String inflateIfRequired(HttpServletRequest request, byte[] b) { - if (HttpMethod.GET.matches(request.getMethod())) { - return Saml2Utils.samlInflate(b); - } - else { - return new String(b, UTF_8); - } - } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java new file mode 100644 index 0000000000..50a511a9b9 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.codec.binary.Base64; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpMethod; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.Assert; + +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} appropriate + * for authenticated a SAML 2.0 Assertion against an + * {@link org.springframework.security.authentication.AuthenticationManager}. + * + * @author Josh Cummings + * @since 5.4 + */ +public class Saml2AuthenticationTokenConverter implements AuthenticationConverter { + private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); + + private final Converter relyingPartyRegistrationResolver; + + /** + * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for resolving + * {@link RelyingPartyRegistration}s + * + * @param relyingPartyRegistrationResolver the strategy for resolving {@link RelyingPartyRegistration}s + */ + public Saml2AuthenticationTokenConverter + (Converter relyingPartyRegistrationResolver) { + Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + } + + /** + * {@inheritDoc} + */ + @Override + public Saml2AuthenticationToken convert(HttpServletRequest request) { + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request); + if (relyingPartyRegistration == null) { + return null; + } + String saml2Response = request.getParameter("SAMLResponse"); + if (saml2Response == null) { + return null; + } + byte[] b = samlDecode(saml2Response); + saml2Response = inflateIfRequired(request, b); + return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response); + } + + private String inflateIfRequired(HttpServletRequest request, byte[] b) { + if (HttpMethod.GET.matches(request.getMethod())) { + return samlInflate(b); + } + else { + return new String(b, UTF_8); + } + } + + private byte[] samlDecode(String s) { + return BASE64.decode(s); + } + + private String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), UTF_8); + } + catch (IOException e) { + throw new Saml2Exception("Unable to inflate string", e); + } + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java new file mode 100644 index 0000000000..3de1cecc9b --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.core; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.zip.Deflater; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; + +import org.apache.commons.codec.binary.Base64; + +import org.springframework.security.saml2.Saml2Exception; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.zip.Deflater.DEFLATED; + +public final class Saml2Utils { + + private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); + + public static String samlEncode(byte[] b) { + return BASE64.encodeAsString(b); + } + + public static byte[] samlDecode(String s) { + return BASE64.decode(s); + } + + public static byte[] samlDeflate(String s) { + try { + ByteArrayOutputStream b = new ByteArrayOutputStream(); + DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true)); + deflater.write(s.getBytes(UTF_8)); + deflater.finish(); + return b.toByteArray(); + } + catch (IOException e) { + throw new Saml2Exception("Unable to deflate string", e); + } + } + + public static String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), UTF_8); + } + catch (IOException e) { + throw new Saml2Exception("Unable to inflate string", e); + } + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java index 5aa604610c..e71ec4b489 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java @@ -48,5 +48,13 @@ public class TestRelyingPartyRegistrations { .credentials(c -> c.add(verificationCertificate)); } - + public static RelyingPartyRegistration.Builder noCredentials() { + return RelyingPartyRegistration.withRegistrationId("registration-id") + .entityId("rp-entity-id") + .assertionConsumerServiceLocation("https://rp.example.org/acs") + .assertingPartyDetails(party -> party + .entityId("ap-entity-id") + .singleSignOnServiceLocation("https://ap.example.org/sso") + ); + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java index 79908bc520..b685adab56 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java @@ -89,7 +89,7 @@ public class Saml2WebSsoAuthenticationFilterTests { failBecauseExceptionWasNotThrown(Saml2AuthenticationException.class); } catch (Exception e) { assertThat(e).isInstanceOf(Saml2AuthenticationException.class); - assertThat(e.getMessage()).isEqualTo("Relying Party Registration not found with ID: non-existent-id"); + assertThat(e.getMessage()).isEqualTo("No relying party registration found"); } } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java new file mode 100644 index 0000000000..3c9fa627c0 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.core.Saml2Utils; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; + +@RunWith(MockitoJUnitRunner.class) +public class Saml2AuthenticationTokenConverterTests { + @Mock + Converter relyingPartyRegistrationResolver; + + RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistration().build(); + + @Test + public void convertWhenSamlResponseThenToken() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter + (this.relyingPartyRegistrationResolver); + when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .thenReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(relyingPartyRegistration.getRegistrationId()); + } + + @Test + public void convertWhenNoSamlResponseThenNull() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter + (this.relyingPartyRegistrationResolver); + when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .thenReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenNoRelyingPartyRegistrationThenNull() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter + (this.relyingPartyRegistrationResolver); + when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .thenReturn(null); + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenGetRequestThenInflates() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter + (this.relyingPartyRegistrationResolver); + when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .thenReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("GET"); + byte[] deflated = Saml2Utils.samlDeflate("response"); + String encoded = Saml2Utils.samlEncode(deflated); + request.setParameter("SAMLResponse", encoded); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(relyingPartyRegistration.getRegistrationId()); + } + + @Test + public void constructorWhenResolverIsNullThenIllegalArgument() { + assertThatCode(() -> new Saml2AuthenticationTokenConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } +} diff --git a/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 010b39c008..0d79b19a50 100644 --- a/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -55,7 +55,7 @@ public class SecurityConfigTests { ) .findFirst() .get(); - for (String field : Arrays.asList("requiresAuthenticationRequestMatcher", "matcher")) { + for (String field : Arrays.asList("requiresAuthenticationRequestMatcher")) { final Object matcher = ReflectionTestUtils.getField(filter, field); final Object pattern = ReflectionTestUtils.getField(matcher, "pattern"); Assert.assertEquals("loginProcessingUrl mismatch", "/sample/jc/saml2/sso/{registrationId}", pattern);