diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java index 57d03b5b76..5d2fa4930d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java @@ -19,6 +19,9 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.csrf.CsrfToken; + +import java.util.Collections; /** * Adds a Filter that will generate a login page if one is not specified otherwise when @@ -65,6 +68,13 @@ public final class DefaultLoginPageConfigurer> @Override public void init(H http) throws Exception { + this.loginPageGeneratingFilter.setResolveHiddenInputs( request -> { + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + if(token == null) { + return Collections.emptyMap(); + } + return Collections.singletonMap(token.getParameterName(), token.getToken()); + }); http.setSharedObject(DefaultLoginPageGeneratingFilter.class, loginPageGeneratingFilter); } diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index 48530b996c..3622d215c7 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -50,13 +50,16 @@ import org.springframework.security.web.authentication.preauth.x509.X509Authenti import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint; import org.springframework.security.web.authentication.www.BasicAuthenticationFilter; +import org.springframework.security.web.csrf.CsrfToken; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.w3c.dom.Element; +import javax.servlet.http.HttpServletRequest; import java.security.SecureRandom; import java.util.*; +import java.util.function.Function; /** * Handles creation of authentication mechanism filters and related beans for <http> @@ -539,6 +542,7 @@ final class AuthenticationConfigBuilder { + "' attribute to set the URL of the login page."); BeanDefinitionBuilder loginPageFilter = BeanDefinitionBuilder .rootBeanDefinition(DefaultLoginPageGeneratingFilter.class); + loginPageFilter.addPropertyValue("resolveHiddenInputs", new CsrfTokenHiddenInputFunction()); if (formFilterId != null) { loginPageFilter.addConstructorArgReference(formFilterId); @@ -831,4 +835,16 @@ final class AuthenticationConfigBuilder { return providers; } + private static class CsrfTokenHiddenInputFunction implements + Function> { + + @Override + public Map apply(HttpServletRequest request) { + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + if(token == null) { + return Collections.emptyMap(); + } + return Collections.singletonMap(token.getParameterName(), token.getToken()); + } + } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java index c5152eacb2..21b5332ce0 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java @@ -20,7 +20,7 @@ import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; -import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; import javax.servlet.FilterChain; @@ -31,7 +31,9 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import java.io.IOException; +import java.util.Collections; import java.util.Map; +import java.util.function.Function; /** * For internal use with namespace configuration in the case where a user doesn't @@ -60,6 +62,8 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { private String openIDusernameParameter; private String openIDrememberMeParameter; private Map oauth2AuthenticationUrlToClientName; + private Function> resolveHiddenInputs = request -> Collections + .emptyMap(); public DefaultLoginPageGeneratingFilter() { @@ -107,6 +111,18 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } } + /** + * Sets a Function used to resolve a Map of the hidden inputs where the key is the + * name of the input and the value is the value of the input. Typically this is used + * to resolve the CSRF token. + * @param resolveHiddenInputs the function to resolve the inputs + */ + public void setResolveHiddenInputs( + Function> resolveHiddenInputs) { + Assert.notNull(resolveHiddenInputs, "resolveHiddenInputs cannot be null"); + this.resolveHiddenInputs = resolveHiddenInputs; + } + public boolean isEnabled() { return formLoginEnabled || openIdEnabled || oauth2LoginEnabled; } @@ -282,11 +298,9 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } private void renderHiddenInputs(StringBuilder sb, HttpServletRequest request) { - CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); - - if (token != null) { - sb.append(" \n"); + for(Map.Entry input : this.resolveHiddenInputs.apply(request).entrySet()) { + sb.append(" \n"); } }