diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index de9c20407e..440a10f933 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,16 @@ package org.springframework.security.config.annotation.web.configurers.saml2; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.net.URLDecoder; import java.time.Duration; import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; +import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; @@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests { verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); } + @Test + public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception { + this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire(); + + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) + .andReturn(); + UriComponents components = UriComponentsBuilder + .fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); + String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); + String decoded = URLDecoder.decode(samlRequest, "UTF-8"); + String inflated = samlInflate(samlDecode(decoded)); + assertThat(inflated).contains("ForceAuthn=\"true\""); + } + private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); @@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests { } } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests(authz -> authz + .anyRequest().authenticated() + ) + .saml2Login(saml2 -> {}); + } + + @Bean + Saml2AuthenticationRequestFactory authenticationRequestFactory() { + OpenSamlAuthenticationRequestFactory authenticationRequestFactory = + new OpenSamlAuthenticationRequestFactory(); + authenticationRequestFactory.setAuthnRequestConsumerResolver( + context -> authnRequest -> authnRequest.setForceAuthn(true)); + return authenticationRequestFactory; + } + } + private static AuthenticationManager getAuthenticationManagerMock(String role) { return new AuthenticationManager() { @@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests { } } + private static org.apache.commons.codec.binary.Base64 BASE64 = + new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'}); + + private static byte[] samlDecode(String s) { + return BASE64.decode(s); + } + + private static String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), UTF_8); + } + catch (IOException e) { + throw new Saml2Exception("Unable to inflate string", e); + } + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 973e4bb7b5..130172e3fb 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -21,6 +21,8 @@ import java.time.Instant; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; import org.joda.time.DateTime; import org.opensaml.saml.common.xml.SAMLConstants; @@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI; + private Function> authnRequestConsumerResolver + = context -> authnRequest -> {}; + @Override @Deprecated public String createAuthenticationRequest(Saml2AuthenticationRequest request) { @@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication } private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) { - return createAuthnRequest(context.getIssuer(), + AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(), context.getDestination(), context.getAssertionConsumerServiceUrl()); + this.authnRequestConsumerResolver.apply(context).accept(authnRequest); + return authnRequest; } private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) { @@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication return auth; } + /** + * Set the {@link AuthnRequest} post-processor resolver + * + * @param authnRequestConsumerResolver + * @since 5.4 + */ + public void setAuthnRequestConsumerResolver( + Function> authnRequestConsumerResolver) { + Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null"); + this.authnRequestConsumerResolver = authnRequestConsumerResolver; + } + /** * ' * Use this {@link Clock} with {@link Instant#now()} for generating diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index c4c5db23fd..cd504ee9bb 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -16,6 +16,9 @@ package org.springframework.security.saml2.provider.service.authentication; +import java.util.function.Consumer; +import java.util.function.Function; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.hamcrest.CoreMatchers.containsString; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; +import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; @@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests { factory.setProtocolBinding("my-invalid-binding"); } + @Test + public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() { + Function> authnRequestConsumerResolver = + mock(Function.class); + when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); + this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); + + this.factory.createPostAuthenticationRequest(this.context); + verify(authnRequestConsumerResolver).apply(this.context); + } + + @Test + public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() { + Function> authnRequestConsumerResolver = + mock(Function.class); + when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); + this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); + + this.factory.createRedirectAuthenticationRequest(this.context); + verify(authnRequestConsumerResolver).apply(this.context); + } + + @Test + public void setAuthnRequestConsumerResolverWhenNullThenException() { + assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null)) + .isInstanceOf(IllegalArgumentException.class); + } + private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) { AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ? factory.createRedirectAuthenticationRequest(context) :