Fix CSRF / DefaultLoginPageGeneratingFilter package tangle

Issue: gh-4636
This commit is contained in:
Rob Winch 2017-10-13 15:11:08 -05:00
parent 7fd1cff3ce
commit a74f7c6faa
3 changed files with 46 additions and 6 deletions

View File

@ -19,6 +19,9 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.csrf.CsrfToken;
import java.util.Collections;
/** /**
* Adds a Filter that will generate a login page if one is not specified otherwise when * Adds a Filter that will generate a login page if one is not specified otherwise when
@ -65,6 +68,13 @@ public final class DefaultLoginPageConfigurer<H extends HttpSecurityBuilder<H>>
@Override @Override
public void init(H http) throws Exception { public void init(H http) throws Exception {
this.loginPageGeneratingFilter.setResolveHiddenInputs( request -> {
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
if(token == null) {
return Collections.emptyMap();
}
return Collections.singletonMap(token.getParameterName(), token.getToken());
});
http.setSharedObject(DefaultLoginPageGeneratingFilter.class, http.setSharedObject(DefaultLoginPageGeneratingFilter.class,
loginPageGeneratingFilter); loginPageGeneratingFilter);
} }

View File

@ -50,13 +50,16 @@ import org.springframework.security.web.authentication.preauth.x509.X509Authenti
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint; import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter; import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import javax.servlet.http.HttpServletRequest;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.*; import java.util.*;
import java.util.function.Function;
/** /**
* Handles creation of authentication mechanism filters and related beans for &lt;http&gt; * Handles creation of authentication mechanism filters and related beans for &lt;http&gt;
@ -539,6 +542,7 @@ final class AuthenticationConfigBuilder {
+ "' attribute to set the URL of the login page."); + "' attribute to set the URL of the login page.");
BeanDefinitionBuilder loginPageFilter = BeanDefinitionBuilder BeanDefinitionBuilder loginPageFilter = BeanDefinitionBuilder
.rootBeanDefinition(DefaultLoginPageGeneratingFilter.class); .rootBeanDefinition(DefaultLoginPageGeneratingFilter.class);
loginPageFilter.addPropertyValue("resolveHiddenInputs", new CsrfTokenHiddenInputFunction());
if (formFilterId != null) { if (formFilterId != null) {
loginPageFilter.addConstructorArgReference(formFilterId); loginPageFilter.addConstructorArgReference(formFilterId);
@ -831,4 +835,16 @@ final class AuthenticationConfigBuilder {
return providers; return providers;
} }
private static class CsrfTokenHiddenInputFunction implements
Function<HttpServletRequest,Map<String,String>> {
@Override
public Map<String, String> apply(HttpServletRequest request) {
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
if(token == null) {
return Collections.emptyMap();
}
return Collections.singletonMap(token.getParameterName(), token.getToken());
}
}
} }

View File

@ -20,7 +20,7 @@ import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices;
import org.springframework.security.web.csrf.CsrfToken; import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean; import org.springframework.web.filter.GenericFilterBean;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -31,7 +31,9 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.function.Function;
/** /**
* For internal use with namespace configuration in the case where a user doesn't * For internal use with namespace configuration in the case where a user doesn't
@ -60,6 +62,8 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
private String openIDusernameParameter; private String openIDusernameParameter;
private String openIDrememberMeParameter; private String openIDrememberMeParameter;
private Map<String, String> oauth2AuthenticationUrlToClientName; private Map<String, String> oauth2AuthenticationUrlToClientName;
private Function<HttpServletRequest,Map<String,String>> resolveHiddenInputs = request -> Collections
.emptyMap();
public DefaultLoginPageGeneratingFilter() { public DefaultLoginPageGeneratingFilter() {
@ -107,6 +111,18 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
} }
/**
* Sets a Function used to resolve a Map of the hidden inputs where the key is the
* name of the input and the value is the value of the input. Typically this is used
* to resolve the CSRF token.
* @param resolveHiddenInputs the function to resolve the inputs
*/
public void setResolveHiddenInputs(
Function<HttpServletRequest, Map<String, String>> resolveHiddenInputs) {
Assert.notNull(resolveHiddenInputs, "resolveHiddenInputs cannot be null");
this.resolveHiddenInputs = resolveHiddenInputs;
}
public boolean isEnabled() { public boolean isEnabled() {
return formLoginEnabled || openIdEnabled || oauth2LoginEnabled; return formLoginEnabled || openIdEnabled || oauth2LoginEnabled;
} }
@ -282,11 +298,9 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
private void renderHiddenInputs(StringBuilder sb, HttpServletRequest request) { private void renderHiddenInputs(StringBuilder sb, HttpServletRequest request) {
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); for(Map.Entry<String,String> input : this.resolveHiddenInputs.apply(request).entrySet()) {
sb.append(" <input name=\"" + input.getKey()
if (token != null) { + "\" type=\"hidden\" value=\"" + input.getValue() + "\" />\n");
sb.append(" <input name=\"" + token.getParameterName()
+ "\" type=\"hidden\" value=\"" + token.getToken() + "\" />\n");
} }
} }