Polish spring-security-web main code

Manually polish `spring-security-web` following the formatting
and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-08-03 22:57:18 -07:00 committed by Rob Winch
parent ef951bae90
commit 5bdd757108
178 changed files with 1676 additions and 2791 deletions

View File

@ -45,7 +45,6 @@ public interface AuthenticationEntryPoint {
* @param request that resulted in an <code>AuthenticationException</code> * @param request that resulted in an <code>AuthenticationException</code>
* @param response so that the user agent can begin authentication * @param response so that the user agent can begin authentication
* @param authException that caused the invocation * @param authException that caused the invocation
*
*/ */
void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException)
throws IOException, ServletException; throws IOException, ServletException;

View File

@ -24,7 +24,9 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
/** /**
* Simple implementation of <tt>RedirectStrategy</tt> which is the default used throughout * Simple implementation of <tt>RedirectStrategy</tt> which is the default used throughout
@ -51,11 +53,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException { public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException {
String redirectUrl = calculateRedirectUrl(request.getContextPath(), url); String redirectUrl = calculateRedirectUrl(request.getContextPath(), url);
redirectUrl = response.encodeRedirectURL(redirectUrl); redirectUrl = response.encodeRedirectURL(redirectUrl);
this.logger.debug(LogMessage.format("Redirecting to '%s'", redirectUrl));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Redirecting to '" + redirectUrl + "'");
}
response.sendRedirect(redirectUrl); response.sendRedirect(redirectUrl);
} }
@ -64,30 +62,20 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
if (isContextRelative()) { if (isContextRelative()) {
return url; return url;
} }
else { return contextPath + url;
return contextPath + url;
}
} }
// Full URL, including http(s):// // Full URL, including http(s)://
if (!isContextRelative()) { if (!isContextRelative()) {
return url; return url;
} }
Assert.isTrue(url.contains(contextPath), "The fully qualified URL does not include context path.");
if (!url.contains(contextPath)) {
throw new IllegalArgumentException("The fully qualified URL does not include context path.");
}
// Calculate the relative URL from the fully qualified URL, minus the last // Calculate the relative URL from the fully qualified URL, minus the last
// occurrence of the scheme and base context. // occurrence of the scheme and base context.
url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme url = url.substring(url.lastIndexOf("://") + 3);
url = url.substring(url.indexOf(contextPath) + contextPath.length()); url = url.substring(url.indexOf(contextPath) + contextPath.length());
if (url.length() > 1 && url.charAt(0) == '/') { if (url.length() > 1 && url.charAt(0) == '/') {
url = url.substring(1); url = url.substring(1);
} }
return url; return url;
} }

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
/** /**
@ -47,7 +48,7 @@ public final class DefaultSecurityFilterChain implements SecurityFilterChain {
} }
public DefaultSecurityFilterChain(RequestMatcher requestMatcher, List<Filter> filters) { public DefaultSecurityFilterChain(RequestMatcher requestMatcher, List<Filter> filters) {
logger.info("Creating filter chain: " + requestMatcher + ", " + filters); logger.info(LogMessage.format("Creating filter chain: %s, %s", requestMatcher, filters));
this.requestMatcher = requestMatcher; this.requestMatcher = requestMatcher;
this.filters = new ArrayList<>(filters); this.filters = new ArrayList<>(filters);
} }

View File

@ -32,6 +32,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.firewall.DefaultRequestRejectedHandler; import org.springframework.security.web.firewall.DefaultRequestRejectedHandler;
import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.FirewalledRequest;
@ -173,47 +174,37 @@ public class FilterChainProxy extends GenericFilterBean {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
boolean clearContext = request.getAttribute(FILTER_APPLIED) == null; boolean clearContext = request.getAttribute(FILTER_APPLIED) == null;
if (clearContext) { if (!clearContext) {
try {
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
doFilterInternal(request, response, chain);
}
catch (RequestRejectedException ex) {
this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
}
finally {
SecurityContextHolder.clearContext();
request.removeAttribute(FILTER_APPLIED);
}
}
else {
doFilterInternal(request, response, chain); doFilterInternal(request, response, chain);
return;
}
try {
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
doFilterInternal(request, response, chain);
}
catch (RequestRejectedException ex) {
this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
}
finally {
SecurityContextHolder.clearContext();
request.removeAttribute(FILTER_APPLIED);
} }
} }
private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain) private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
FirewalledRequest firewallRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request);
FirewalledRequest fwRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request); HttpServletResponse firewallResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response);
HttpServletResponse fwResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response); List<Filter> filters = getFilters(firewallRequest);
List<Filter> filters = getFilters(fwRequest);
if (filters == null || filters.size() == 0) { if (filters == null || filters.size() == 0) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(firewallRequest)
logger.debug(UrlUtils.buildRequestUrl(fwRequest) + ((filters != null) ? " has an empty filter list" : " has no matching filters")));
+ ((filters != null) ? " has an empty filter list" : " has no matching filters")); firewallRequest.reset();
} chain.doFilter(firewallRequest, firewallResponse);
fwRequest.reset();
chain.doFilter(fwRequest, fwResponse);
return; return;
} }
VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters);
VirtualFilterChain vfc = new VirtualFilterChain(fwRequest, chain, filters); virtualFilterChain.doFilter(firewallRequest, firewallResponse);
vfc.doFilter(fwRequest, fwResponse);
} }
/** /**
@ -227,7 +218,6 @@ public class FilterChainProxy extends GenericFilterBean {
return chain.getFilters(); return chain.getFilters();
} }
} }
return null; return null;
} }
@ -286,7 +276,6 @@ public class FilterChainProxy extends GenericFilterBean {
sb.append("Filter Chains: "); sb.append("Filter Chains: ");
sb.append(this.filterChains); sb.append(this.filterChains);
sb.append("]"); sb.append("]");
return sb.toString(); return sb.toString();
} }
@ -317,30 +306,19 @@ public class FilterChainProxy extends GenericFilterBean {
@Override @Override
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (this.currentPosition == this.size) { if (this.currentPosition == this.size) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest)
logger.debug(UrlUtils.buildRequestUrl(this.firewalledRequest) + " reached end of additional filter chain; proceeding with original chain"));
+ " reached end of additional filter chain; proceeding with original chain");
}
// Deactivate path stripping as we exit the security filter chain // Deactivate path stripping as we exit the security filter chain
this.firewalledRequest.reset(); this.firewalledRequest.reset();
this.originalChain.doFilter(request, response); this.originalChain.doFilter(request, response);
return;
} }
else { this.currentPosition++;
this.currentPosition++; Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position "
Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1); + this.currentPosition + " of " + this.size + " in additional filter chain; firing Filter: '"
+ nextFilter.getClass().getSimpleName() + "'"));
if (logger.isDebugEnabled()) { nextFilter.doFilter(request, response, this);
logger.debug(
UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position " + this.currentPosition
+ " of " + this.size + " in additional filter chain; firing Filter: '"
+ nextFilter.getClass().getSimpleName() + "'");
}
nextFilter.doFilter(request, response, this);
}
} }
} }

View File

@ -37,6 +37,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
/** /**
* Holds objects associated with a HTTP filter. * Holds objects associated with a HTTP filter.
@ -65,10 +66,7 @@ public class FilterInvocation {
private HttpServletResponse response; private HttpServletResponse response;
public FilterInvocation(ServletRequest request, ServletResponse response, FilterChain chain) { public FilterInvocation(ServletRequest request, ServletResponse response, FilterChain chain) {
if ((request == null) || (response == null) || (chain == null)) { Assert.isTrue(request != null && response != null && chain != null, "Cannot pass null values to constructor");
throw new IllegalArgumentException("Cannot pass null values to constructor");
}
this.request = (HttpServletRequest) request; this.request = (HttpServletRequest) request;
this.response = (HttpServletResponse) response; this.response = (HttpServletResponse) response;
this.chain = chain; this.chain = chain;
@ -84,9 +82,7 @@ public class FilterInvocation {
public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) { public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) {
DummyRequest request = new DummyRequest(); DummyRequest request = new DummyRequest();
if (contextPath == null) { contextPath = (contextPath != null) ? contextPath : "/cp";
contextPath = "/cp";
}
request.setContextPath(contextPath); request.setContextPath(contextPath);
request.setServletPath(servletPath); request.setServletPath(servletPath);
request.setRequestURI(contextPath + servletPath + ((pathInfo != null) ? pathInfo : "")); request.setRequestURI(contextPath + servletPath + ((pathInfo != null) ? pathInfo : ""));
@ -256,9 +252,7 @@ public class FilterInvocation {
if (value == null) { if (value == null) {
return -1; return -1;
} }
else { return Integer.parseInt(value);
return Integer.parseInt(value);
}
} }
void addHeader(String name, String value) { void addHeader(String name, String value) {
@ -267,8 +261,8 @@ public class FilterInvocation {
@Override @Override
public String getParameter(String name) { public String getParameter(String name) {
String[] arr = this.parameters.get(name); String[] array = this.parameters.get(name);
return (arr != null && arr.length > 0) ? arr[0] : null; return (array != null && array.length > 0) ? array[0] : null;
} }
@Override @Override
@ -317,7 +311,6 @@ public class FilterInvocation {
private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable { private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable {
Constructor<Lookup> constructor = Lookup.class.getDeclaredConstructor(Class.class); Constructor<Lookup> constructor = Lookup.class.getDeclaredConstructor(Class.class);
constructor.setAccessible(true); constructor.setAccessible(true);
Class<?> clazz = method.getDeclaringClass(); Class<?> clazz = method.getDeclaringClass();
return constructor.newInstance(clazz).in(clazz).unreflectSpecial(method, clazz).bindTo(proxy) return constructor.newInstance(clazz).in(clazz).unreflectSpecial(method, clazz).bindTo(proxy)
.invokeWithArguments(args); .invokeWithArguments(args);

View File

@ -56,7 +56,6 @@ public class PortMapperImpl implements PortMapper {
return httpPort; return httpPort;
} }
} }
return null; return null;
} }
@ -88,24 +87,19 @@ public class PortMapperImpl implements PortMapper {
*/ */
public void setPortMappings(Map<String, String> newMappings) { public void setPortMappings(Map<String, String> newMappings) {
Assert.notNull(newMappings, "A valid list of HTTPS port mappings must be provided"); Assert.notNull(newMappings, "A valid list of HTTPS port mappings must be provided");
this.httpsPortMappings.clear(); this.httpsPortMappings.clear();
for (Map.Entry<String, String> entry : newMappings.entrySet()) { for (Map.Entry<String, String> entry : newMappings.entrySet()) {
Integer httpPort = Integer.valueOf(entry.getKey()); Integer httpPort = Integer.valueOf(entry.getKey());
Integer httpsPort = Integer.valueOf(entry.getValue()); Integer httpsPort = Integer.valueOf(entry.getValue());
Assert.isTrue(isInPortRange(httpPort) && isInPortRange(httpsPort),
if ((httpPort < 1) || (httpPort > 65535) || (httpsPort < 1) || (httpsPort > 65535)) { () -> "one or both ports out of legal range: " + httpPort + ", " + httpsPort);
throw new IllegalArgumentException(
"one or both ports out of legal range: " + httpPort + ", " + httpsPort);
}
this.httpsPortMappings.put(httpPort, httpsPort); this.httpsPortMappings.put(httpPort, httpsPort);
} }
Assert.isTrue(!this.httpsPortMappings.isEmpty(), "must map at least one port");
}
if (this.httpsPortMappings.size() < 1) { private boolean isInPortRange(int port) {
throw new IllegalArgumentException("must map at least one port"); return port >= 1 && port <= 65535;
}
} }
} }

View File

@ -45,24 +45,19 @@ public class PortResolverImpl implements PortResolver {
@Override @Override
public int getServerPort(ServletRequest request) { public int getServerPort(ServletRequest request) {
int serverPort = request.getServerPort(); int serverPort = request.getServerPort();
Integer portLookup = null;
String scheme = request.getScheme().toLowerCase(); String scheme = request.getScheme().toLowerCase();
Integer mappedPort = getMappedPort(serverPort, scheme);
return (mappedPort != null) ? mappedPort : serverPort;
}
private Integer getMappedPort(int serverPort, String scheme) {
if ("http".equals(scheme)) { if ("http".equals(scheme)) {
portLookup = this.portMapper.lookupHttpPort(serverPort); return this.portMapper.lookupHttpPort(serverPort);
} }
else if ("https".equals(scheme)) { if ("https".equals(scheme)) {
portLookup = this.portMapper.lookupHttpsPort(serverPort); return this.portMapper.lookupHttpsPort(serverPort);
} }
return null;
if (portLookup != null) {
// IE 6 bug
serverPort = portLookup;
}
return serverPort;
} }
public void setPortMapper(PortMapper portMapper) { public void setPortMapper(PortMapper portMapper) {

View File

@ -18,7 +18,6 @@ package org.springframework.security.web.access;
import java.io.IOException; import java.io.IOException;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -29,6 +28,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.WebAttributes; import org.springframework.security.web.WebAttributes;
import org.springframework.util.Assert;
/** /**
* Base implementation of {@link AccessDeniedHandler}. * Base implementation of {@link AccessDeniedHandler}.
@ -52,22 +52,19 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler {
@Override @Override
public void handle(HttpServletRequest request, HttpServletResponse response, public void handle(HttpServletRequest request, HttpServletResponse response,
AccessDeniedException accessDeniedException) throws IOException, ServletException { AccessDeniedException accessDeniedException) throws IOException, ServletException {
if (!response.isCommitted()) { if (response.isCommitted()) {
if (this.errorPage != null) { return;
// Put exception into request scope (perhaps of use to a view)
request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException);
// Set the 403 status code.
response.setStatus(HttpStatus.FORBIDDEN.value());
// forward to error page.
RequestDispatcher dispatcher = request.getRequestDispatcher(this.errorPage);
dispatcher.forward(request, response);
}
else {
response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase());
}
} }
if (this.errorPage == null) {
response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase());
return;
}
// Put exception into request scope (perhaps of use to a view)
request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException);
// Set the 403 status code.
response.setStatus(HttpStatus.FORBIDDEN.value());
// forward to error page.
request.getRequestDispatcher(this.errorPage).forward(request, response);
} }
/** /**
@ -78,10 +75,7 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler {
* limitations * limitations
*/ */
public void setErrorPage(String errorPage) { public void setErrorPage(String errorPage) {
if ((errorPage != null) && !errorPage.startsWith("/")) { Assert.isTrue(errorPage == null || errorPage.startsWith("/"), "errorPage must begin with '/'");
throw new IllegalArgumentException("errorPage must begin with '/'");
}
this.errorPage = errorPage; this.errorPage = errorPage;
} }

View File

@ -21,6 +21,7 @@ import java.util.Collection;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.access.intercept.AbstractSecurityInterceptor; import org.springframework.security.access.intercept.AbstractSecurityInterceptor;
@ -47,7 +48,6 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
"AbstractSecurityInterceptor does not support FilterInvocations"); "AbstractSecurityInterceptor does not support FilterInvocations");
Assert.notNull(securityInterceptor.getAccessDecisionManager(), Assert.notNull(securityInterceptor.getAccessDecisionManager(),
"AbstractSecurityInterceptor must provide a non-null AccessDecisionManager"); "AbstractSecurityInterceptor must provide a non-null AccessDecisionManager");
this.securityInterceptor = securityInterceptor; this.securityInterceptor = securityInterceptor;
} }
@ -82,34 +82,23 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
@Override @Override
public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) { public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) {
Assert.notNull(uri, "uri parameter is required"); Assert.notNull(uri, "uri parameter is required");
FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method);
FilterInvocation fi = new FilterInvocation(contextPath, uri, method); Collection<ConfigAttribute> attributes = this.securityInterceptor.obtainSecurityMetadataSource()
Collection<ConfigAttribute> attrs = this.securityInterceptor.obtainSecurityMetadataSource().getAttributes(fi); .getAttributes(filterInvocation);
if (attributes == null) {
if (attrs == null) { return (!this.securityInterceptor.isRejectPublicInvocations());
if (this.securityInterceptor.isRejectPublicInvocations()) {
return false;
}
return true;
} }
if (authentication == null) { if (authentication == null) {
return false; return false;
} }
try { try {
this.securityInterceptor.getAccessDecisionManager().decide(authentication, fi, attrs); this.securityInterceptor.getAccessDecisionManager().decide(authentication, filterInvocation, attributes);
return true;
} }
catch (AccessDeniedException unauthorized) { catch (AccessDeniedException ex) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("%s denied for %s", filterInvocation, authentication), ex);
logger.debug(fi.toString() + " denied for " + authentication.toString(), unauthorized);
}
return false; return false;
} }
return true;
} }
} }

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
@ -107,14 +108,15 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
try { try {
chain.doFilter(request, response); chain.doFilter(request, response);
this.logger.debug("Chain processed normally"); this.logger.debug("Chain processed normally");
} }
catch (IOException ex) { catch (IOException ex) {
@ -123,38 +125,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
catch (Exception ex) { catch (Exception ex) {
// Try to extract a SpringSecurityException from the stacktrace // Try to extract a SpringSecurityException from the stacktrace
Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex);
RuntimeException ase = (AuthenticationException) this.throwableAnalyzer RuntimeException securityException = (AuthenticationException) this.throwableAnalyzer
.getFirstThrowableOfType(AuthenticationException.class, causeChain); .getFirstThrowableOfType(AuthenticationException.class, causeChain);
if (securityException == null) {
if (ase == null) { securityException = (AccessDeniedException) this.throwableAnalyzer
ase = (AccessDeniedException) this.throwableAnalyzer
.getFirstThrowableOfType(AccessDeniedException.class, causeChain); .getFirstThrowableOfType(AccessDeniedException.class, causeChain);
} }
if (securityException == null) {
if (ase != null) { rethrow(ex);
if (response.isCommitted()) {
throw new ServletException(
"Unable to handle the Spring Security Exception because the response is already committed.",
ex);
}
handleSpringSecurityException(request, response, chain, ase);
} }
else { if (response.isCommitted()) {
// Rethrow ServletExceptions and RuntimeExceptions as-is throw new ServletException("Unable to handle the Spring Security Exception "
if (ex instanceof ServletException) { + "because the response is already committed.", ex);
throw (ServletException) ex;
}
else if (ex instanceof RuntimeException) {
throw (RuntimeException) ex;
}
// Wrap other Exceptions. This shouldn't actually happen
// as we've already covered all the possibilities for doFilter
throw new RuntimeException(ex);
} }
handleSpringSecurityException(request, response, chain, securityException);
} }
} }
private void rethrow(Exception ex) throws ServletException {
// Rethrow ServletExceptions and RuntimeExceptions as-is
if (ex instanceof ServletException) {
throw (ServletException) ex;
}
if (ex instanceof RuntimeException) {
throw (RuntimeException) ex;
}
// Wrap other Exceptions. This shouldn't actually happen
// as we've already covered all the possibilities for doFilter
throw new RuntimeException(ex);
}
public AuthenticationEntryPoint getAuthenticationEntryPoint() { public AuthenticationEntryPoint getAuthenticationEntryPoint() {
return this.authenticationEntryPoint; return this.authenticationEntryPoint;
} }
@ -166,32 +166,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
private void handleSpringSecurityException(HttpServletRequest request, HttpServletResponse response, private void handleSpringSecurityException(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, RuntimeException exception) throws IOException, ServletException { FilterChain chain, RuntimeException exception) throws IOException, ServletException {
if (exception instanceof AuthenticationException) { if (exception instanceof AuthenticationException) {
this.logger.debug("Authentication exception occurred; redirecting to authentication entry point", handleAuthenticationException(request, response, chain, (AuthenticationException) exception);
exception);
sendStartAuthentication(request, response, chain, (AuthenticationException) exception);
} }
else if (exception instanceof AccessDeniedException) { else if (exception instanceof AccessDeniedException) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); handleAccessDeniedException(request, response, chain, (AccessDeniedException) exception);
if (this.authenticationTrustResolver.isAnonymous(authentication) }
|| this.authenticationTrustResolver.isRememberMe(authentication)) { }
this.logger.debug(
"Access is denied (user is " + (this.authenticationTrustResolver.isAnonymous(authentication)
? "anonymous" : "not fully authenticated")
+ "); redirecting to authentication entry point",
exception);
sendStartAuthentication(request, response, chain, private void handleAuthenticationException(HttpServletRequest request, HttpServletResponse response,
new InsufficientAuthenticationException( FilterChain chain, AuthenticationException exception) throws ServletException, IOException {
this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication", this.logger.debug("Authentication exception occurred; redirecting to authentication entry point", exception);
"Full authentication is required to access this resource"))); sendStartAuthentication(request, response, chain, exception);
} }
else {
this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler",
exception);
this.accessDeniedHandler.handle(request, response, (AccessDeniedException) exception); private void handleAccessDeniedException(HttpServletRequest request, HttpServletResponse response,
} FilterChain chain, AccessDeniedException exception) throws ServletException, IOException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
boolean isAnonymous = this.authenticationTrustResolver.isAnonymous(authentication);
if (isAnonymous || this.authenticationTrustResolver.isRememberMe(authentication)) {
this.logger.debug(LogMessage
.of(() -> "Access is denied (user is " + (isAnonymous ? "anonymous" : "not fully authenticated")
+ "); redirecting to authentication entry point"),
exception);
sendStartAuthentication(request, response, chain,
new InsufficientAuthenticationException(
this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication",
"Full authentication is required to access this resource")));
}
else {
this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler", exception);
this.accessDeniedHandler.handle(request, response, exception);
} }
} }
@ -232,7 +236,6 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
@Override @Override
protected void initExtractorMap() { protected void initExtractorMap() {
super.initExtractorMap(); super.initExtractorMap();
registerExtractor(ServletException.class, (throwable) -> { registerExtractor(ServletException.class, (throwable) -> {
ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class); ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class);
return ((ServletException) throwable).getRootCause(); return ((ServletException) throwable).getRootCause();

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.PortMapper; import org.springframework.security.web.PortMapper;
import org.springframework.security.web.PortMapperImpl; import org.springframework.security.web.PortMapperImpl;
@ -43,10 +44,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
private PortResolver portResolver = new PortResolverImpl(); private PortResolver portResolver = new PortResolverImpl();
/** The scheme ("http://" or "https://") */ /**
* The scheme ("http://" or "https://")
*/
private final String scheme; private final String scheme;
/** The standard port for the scheme (80 for http, 443 for https) */ /**
* The standard port for the scheme (80 for http, 443 for https)
*/
private final int standardPort; private final int standardPort;
private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
@ -60,21 +65,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
public void commence(HttpServletRequest request, HttpServletResponse response) throws IOException { public void commence(HttpServletRequest request, HttpServletResponse response) throws IOException {
String queryString = request.getQueryString(); String queryString = request.getQueryString();
String redirectUrl = request.getRequestURI() + ((queryString != null) ? ("?" + queryString) : ""); String redirectUrl = request.getRequestURI() + ((queryString != null) ? ("?" + queryString) : "");
Integer currentPort = this.portResolver.getServerPort(request); Integer currentPort = this.portResolver.getServerPort(request);
Integer redirectPort = getMappedPort(currentPort); Integer redirectPort = getMappedPort(currentPort);
if (redirectPort != null) { if (redirectPort != null) {
boolean includePort = redirectPort != this.standardPort; boolean includePort = redirectPort != this.standardPort;
String port = (includePort) ? (":" + redirectPort) : "";
redirectUrl = this.scheme + request.getServerName() + ((includePort) ? (":" + redirectPort) : "") redirectUrl = this.scheme + request.getServerName() + port + redirectUrl;
+ redirectUrl;
} }
this.logger.debug(LogMessage.format("Redirecting to: %s", redirectUrl));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Redirecting to: " + redirectUrl);
}
this.redirectStrategy.sendRedirect(request, response, redirectUrl); this.redirectStrategy.sendRedirect(request, response, redirectUrl);
} }

View File

@ -64,10 +64,8 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
return; return;
} }
} }
for (ChannelProcessor processor : this.channelProcessors) { for (ChannelProcessor processor : this.channelProcessors) {
processor.decide(invocation, config); processor.decide(invocation, config);
if (invocation.getResponse().isCommitted()) { if (invocation.getResponse().isCommitted()) {
break; break;
} }
@ -79,11 +77,10 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
} }
@SuppressWarnings("cast") @SuppressWarnings("cast")
public void setChannelProcessors(List<?> newList) { public void setChannelProcessors(List<?> channelProcessors) {
Assert.notEmpty(newList, "A list of ChannelProcessors is required"); Assert.notEmpty(channelProcessors, "A list of ChannelProcessors is required");
this.channelProcessors = new ArrayList<>(newList.size()); this.channelProcessors = new ArrayList<>(channelProcessors.size());
for (Object currentObject : channelProcessors) {
for (Object currentObject : newList) {
Assert.isInstanceOf(ChannelProcessor.class, currentObject, () -> "ChannelProcessor " Assert.isInstanceOf(ChannelProcessor.class, currentObject, () -> "ChannelProcessor "
+ currentObject.getClass().getName() + " must implement ChannelProcessor"); + currentObject.getClass().getName() + " must implement ChannelProcessor");
this.channelProcessors.add((ChannelProcessor) currentObject); this.channelProcessors.add((ChannelProcessor) currentObject);
@ -95,13 +92,11 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
if (ANY_CHANNEL.equals(attribute.getAttribute())) { if (ANY_CHANNEL.equals(attribute.getAttribute())) {
return true; return true;
} }
for (ChannelProcessor processor : this.channelProcessors) { for (ChannelProcessor processor : this.channelProcessors) {
if (processor.supports(attribute)) { if (processor.supports(attribute)) {
return true; return true;
} }
} }
return false; return false;
} }

View File

@ -28,6 +28,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.FilterInvocation;
import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource; import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource;
@ -93,35 +94,26 @@ public class ChannelProcessingFilter extends GenericFilterBean {
public void afterPropertiesSet() { public void afterPropertiesSet() {
Assert.notNull(this.securityMetadataSource, "securityMetadataSource must be specified"); Assert.notNull(this.securityMetadataSource, "securityMetadataSource must be specified");
Assert.notNull(this.channelDecisionManager, "channelDecisionManager must be specified"); Assert.notNull(this.channelDecisionManager, "channelDecisionManager must be specified");
Collection<ConfigAttribute> attributes = this.securityMetadataSource.getAllConfigAttributes();
Collection<ConfigAttribute> attrDefs = this.securityMetadataSource.getAllConfigAttributes(); if (attributes == null) {
this.logger.warn("Could not validate configuration attributes as the "
if (attrDefs == null) { + "FilterInvocationSecurityMetadataSource did not return any attributes");
if (this.logger.isWarnEnabled()) {
this.logger.warn(
"Could not validate configuration attributes as the FilterInvocationSecurityMetadataSource did "
+ "not return any attributes");
}
return; return;
} }
Set<ConfigAttribute> unsupportedAttributes = getUnsupportedAttributes(attributes);
Assert.isTrue(unsupportedAttributes.isEmpty(),
() -> "Unsupported configuration attributes: " + unsupportedAttributes);
this.logger.info("Validated configuration attributes");
}
private Set<ConfigAttribute> getUnsupportedAttributes(Collection<ConfigAttribute> attrDefs) {
Set<ConfigAttribute> unsupportedAttributes = new HashSet<>(); Set<ConfigAttribute> unsupportedAttributes = new HashSet<>();
for (ConfigAttribute attr : attrDefs) { for (ConfigAttribute attr : attrDefs) {
if (!this.channelDecisionManager.supports(attr)) { if (!this.channelDecisionManager.supports(attr)) {
unsupportedAttributes.add(attr); unsupportedAttributes.add(attr);
} }
} }
return unsupportedAttributes;
if (unsupportedAttributes.size() == 0) {
if (this.logger.isInfoEnabled()) {
this.logger.info("Validated configuration attributes");
}
}
else {
throw new IllegalArgumentException("Unsupported configuration attributes: " + unsupportedAttributes);
}
} }
@Override @Override
@ -129,22 +121,15 @@ public class ChannelProcessingFilter extends GenericFilterBean {
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res; HttpServletResponse response = (HttpServletResponse) res;
FilterInvocation filterInvocation = new FilterInvocation(request, response, chain);
FilterInvocation fi = new FilterInvocation(request, response, chain); Collection<ConfigAttribute> attributes = this.securityMetadataSource.getAttributes(filterInvocation);
Collection<ConfigAttribute> attr = this.securityMetadataSource.getAttributes(fi); if (attributes != null) {
this.logger.debug(LogMessage.format("Request: %s; ConfigAttributes: %s", filterInvocation, attributes));
if (attr != null) { this.channelDecisionManager.decide(filterInvocation, attributes);
if (this.logger.isDebugEnabled()) { if (filterInvocation.getResponse().isCommitted()) {
this.logger.debug("Request: " + fi.toString() + "; ConfigAttributes: " + attr);
}
this.channelDecisionManager.decide(fi, attr);
if (fi.getResponse().isCommitted()) {
return; return;
} }
} }
chain.doFilter(request, response); chain.doFilter(request, response);
} }

View File

@ -40,7 +40,6 @@ public interface ChannelProcessor {
/** /**
* Decided whether the presented {@link FilterInvocation} provides the appropriate * Decided whether the presented {@link FilterInvocation} provides the appropriate
* level of channel security based on the requested list of <tt>ConfigAttribute</tt>s. * level of channel security based on the requested list of <tt>ConfigAttribute</tt>s.
*
*/ */
void decide(FilterInvocation invocation, Collection<ConfigAttribute> config) throws IOException, ServletException; void decide(FilterInvocation invocation, Collection<ConfigAttribute> config) throws IOException, ServletException;

View File

@ -55,10 +55,7 @@ public class InsecureChannelProcessor implements InitializingBean, ChannelProces
@Override @Override
public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config) public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config)
throws IOException, ServletException { throws IOException, ServletException {
if ((invocation == null) || (config == null)) { Assert.isTrue(invocation != null && config != null, "Nulls cannot be provided");
throw new IllegalArgumentException("Nulls cannot be provided");
}
for (ConfigAttribute attribute : config) { for (ConfigAttribute attribute : config) {
if (supports(attribute)) { if (supports(attribute)) {
if (invocation.getHttpRequest().isSecure()) { if (invocation.getHttpRequest().isSecure()) {

View File

@ -56,7 +56,6 @@ public class SecureChannelProcessor implements InitializingBean, ChannelProcesso
public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config) public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config)
throws IOException, ServletException { throws IOException, ServletException {
Assert.isTrue((invocation != null) && (config != null), "Nulls cannot be provided"); Assert.isTrue((invocation != null) && (config != null), "Nulls cannot be provided");
for (ConfigAttribute attribute : config) { for (ConfigAttribute attribute : config) {
if (supports(attribute)) { if (supports(attribute)) {
if (!invocation.getHttpRequest().isSecure()) { if (!invocation.getHttpRequest().isSecure()) {

View File

@ -41,25 +41,37 @@ abstract class AbstractVariableEvaluationContextPostProcessor
@Override @Override
public final EvaluationContext postProcess(EvaluationContext context, FilterInvocation invocation) { public final EvaluationContext postProcess(EvaluationContext context, FilterInvocation invocation) {
final HttpServletRequest request = invocation.getHttpRequest(); return new VariableEvaluationContext(context, invocation.getHttpRequest());
return new DelegatingEvaluationContext(context) {
private Map<String, String> variables;
@Override
public Object lookupVariable(String name) {
Object result = super.lookupVariable(name);
if (result != null) {
return result;
}
if (this.variables == null) {
this.variables = extractVariables(request);
}
return this.variables.get(name);
}
};
} }
abstract Map<String, String> extractVariables(HttpServletRequest request); abstract Map<String, String> extractVariables(HttpServletRequest request);
/**
* {@link DelegatingEvaluationContext} to expose variable.
*/
class VariableEvaluationContext extends DelegatingEvaluationContext {
private final HttpServletRequest request;
private Map<String, String> variables;
VariableEvaluationContext(EvaluationContext delegate, HttpServletRequest request) {
super(delegate);
this.request = request;
}
@Override
public Object lookupVariable(String name) {
Object result = super.lookupVariable(name);
if (result != null) {
return result;
}
if (this.variables == null) {
this.variables = extractVariables(this.request);
}
return this.variables.get(name);
}
}
} }

View File

@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.function.BiConsumer;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -58,29 +59,29 @@ public final class ExpressionBasedFilterInvocationSecurityMetadataSource
private static LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> processMap( private static LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> processMap(
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap, ExpressionParser parser) { LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap, ExpressionParser parser) {
Assert.notNull(parser, "SecurityExpressionHandler returned a null parser object"); Assert.notNull(parser, "SecurityExpressionHandler returned a null parser object");
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> processed = new LinkedHashMap<>(requestMap);
requestMap.forEach((request, value) -> process(parser, request, value, processed::put));
return processed;
}
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestToExpressionAttributesMap = new LinkedHashMap<>( private static void process(ExpressionParser parser, RequestMatcher request, Collection<ConfigAttribute> value,
requestMap); BiConsumer<RequestMatcher, Collection<ConfigAttribute>> consumer) {
String expression = getExpression(request, value);
for (Map.Entry<RequestMatcher, Collection<ConfigAttribute>> entry : requestMap.entrySet()) { logger.debug("Adding web access control expression '" + expression + "', for " + request);
RequestMatcher request = entry.getKey(); AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request);
Assert.isTrue(entry.getValue().size() == 1, () -> "Expected a single expression attribute for " + request); ArrayList<ConfigAttribute> processed = new ArrayList<>(1);
ArrayList<ConfigAttribute> attributes = new ArrayList<>(1); try {
String expression = entry.getValue().toArray(new ConfigAttribute[1])[0].getAttribute(); processed.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor));
logger.debug("Adding web access control expression '" + expression + "', for " + request);
AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request);
try {
attributes.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor));
}
catch (ParseException ex) {
throw new IllegalArgumentException("Failed to parse expression '" + expression + "'");
}
requestToExpressionAttributesMap.put(request, attributes);
} }
catch (ParseException ex) {
throw new IllegalArgumentException("Failed to parse expression '" + expression + "'");
}
consumer.accept(request, processed);
}
return requestToExpressionAttributesMap; private static String getExpression(RequestMatcher request, Collection<ConfigAttribute> value) {
Assert.isTrue(value.size() == 1, () -> "Expected a single expression attribute for " + request);
return value.toArray(new ConfigAttribute[1])[0].getAttribute();
} }
private static AbstractVariableEvaluationContextPostProcessor createPostProcessor(RequestMatcher request) { private static AbstractVariableEvaluationContextPostProcessor createPostProcessor(RequestMatcher request) {

View File

@ -25,6 +25,7 @@ import org.springframework.security.access.expression.ExpressionUtils;
import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.access.expression.SecurityExpressionHandler;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.FilterInvocation;
import org.springframework.util.Assert;
/** /**
* Voter which handles web authorisation decisions. * Voter which handles web authorisation decisions.
@ -37,21 +38,19 @@ public class WebExpressionVoter implements AccessDecisionVoter<FilterInvocation>
private SecurityExpressionHandler<FilterInvocation> expressionHandler = new DefaultWebSecurityExpressionHandler(); private SecurityExpressionHandler<FilterInvocation> expressionHandler = new DefaultWebSecurityExpressionHandler();
@Override @Override
public int vote(Authentication authentication, FilterInvocation fi, Collection<ConfigAttribute> attributes) { public int vote(Authentication authentication, FilterInvocation filterInvocation,
assert authentication != null; Collection<ConfigAttribute> attributes) {
assert fi != null; Assert.notNull(authentication, "authentication must not be null");
assert attributes != null; Assert.notNull(filterInvocation, "filterInvocation must not be null");
Assert.notNull(attributes, "attributes must not be null");
WebExpressionConfigAttribute weca = findConfigAttribute(attributes); WebExpressionConfigAttribute webExpressionConfigAttribute = findConfigAttribute(attributes);
if (webExpressionConfigAttribute == null) {
if (weca == null) {
return ACCESS_ABSTAIN; return ACCESS_ABSTAIN;
} }
EvaluationContext ctx = webExpressionConfigAttribute.postProcess(
EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, fi); this.expressionHandler.createEvaluationContext(authentication, filterInvocation), filterInvocation);
ctx = weca.postProcess(ctx, fi); return ExpressionUtils.evaluateAsBoolean(webExpressionConfigAttribute.getAuthorizeExpression(), ctx)
? ACCESS_GRANTED : ACCESS_DENIED;
return ExpressionUtils.evaluateAsBoolean(weca.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
} }
private WebExpressionConfigAttribute findConfigAttribute(Collection<ConfigAttribute> attributes) { private WebExpressionConfigAttribute findConfigAttribute(Collection<ConfigAttribute> attributes) {

View File

@ -29,13 +29,13 @@ import org.springframework.security.web.util.matcher.IpAddressMatcher;
*/ */
public class WebSecurityExpressionRoot extends SecurityExpressionRoot { public class WebSecurityExpressionRoot extends SecurityExpressionRoot {
// private FilterInvocation filterInvocation; /**
/** Allows direct access to the request object */ * Allows direct access to the request object
*/
public final HttpServletRequest request; public final HttpServletRequest request;
public WebSecurityExpressionRoot(Authentication a, FilterInvocation fi) { public WebSecurityExpressionRoot(Authentication a, FilterInvocation fi) {
super(a); super(a);
// this.filterInvocation = fi;
this.request = fi.getRequest(); this.request = fi.getRequest();
} }
@ -47,7 +47,8 @@ public class WebSecurityExpressionRoot extends SecurityExpressionRoot {
* @return true if the IP address of the current request is in the required range. * @return true if the IP address of the current request is in the required range.
*/ */
public boolean hasIpAddress(String ipAddress) { public boolean hasIpAddress(String ipAddress) {
return (new IpAddressMatcher(ipAddress).matches(this.request)); IpAddressMatcher matcher = new IpAddressMatcher(ipAddress);
return matcher.matches(this.request);
} }
} }

View File

@ -65,18 +65,13 @@ public class DefaultFilterInvocationSecurityMetadataSource implements FilterInvo
*/ */
public DefaultFilterInvocationSecurityMetadataSource( public DefaultFilterInvocationSecurityMetadataSource(
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap) { LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap) {
this.requestMap = requestMap; this.requestMap = requestMap;
} }
@Override @Override
public Collection<ConfigAttribute> getAllConfigAttributes() { public Collection<ConfigAttribute> getAllConfigAttributes() {
Set<ConfigAttribute> allAttributes = new HashSet<>(); Set<ConfigAttribute> allAttributes = new HashSet<>();
this.requestMap.values().forEach(allAttributes::addAll);
for (Map.Entry<RequestMatcher, Collection<ConfigAttribute>> entry : this.requestMap.entrySet()) {
allAttributes.addAll(entry.getValue());
}
return allAttributes; return allAttributes;
} }

View File

@ -78,8 +78,7 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple
@Override @Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
FilterInvocation fi = new FilterInvocation(request, response, chain); invoke(new FilterInvocation(request, response, chain));
invoke(fi);
} }
public FilterInvocationSecurityMetadataSource getSecurityMetadataSource() { public FilterInvocationSecurityMetadataSource getSecurityMetadataSource() {
@ -100,30 +99,30 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple
return FilterInvocation.class; return FilterInvocation.class;
} }
public void invoke(FilterInvocation fi) throws IOException, ServletException { public void invoke(FilterInvocation filterInvocation) throws IOException, ServletException {
if ((fi.getRequest() != null) && (fi.getRequest().getAttribute(FILTER_APPLIED) != null) if (isApplied(filterInvocation) && this.observeOncePerRequest) {
&& this.observeOncePerRequest) {
// filter already applied to this request and user wants us to observe // filter already applied to this request and user wants us to observe
// once-per-request handling, so don't re-do security checking // once-per-request handling, so don't re-do security checking
fi.getChain().doFilter(fi.getRequest(), fi.getResponse()); filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse());
return;
} }
else { // first time this request being called, so perform security checking
// first time this request being called, so perform security checking if (filterInvocation.getRequest() != null && this.observeOncePerRequest) {
if (fi.getRequest() != null && this.observeOncePerRequest) { filterInvocation.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE);
fi.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE);
}
InterceptorStatusToken token = super.beforeInvocation(fi);
try {
fi.getChain().doFilter(fi.getRequest(), fi.getResponse());
}
finally {
super.finallyInvocation(token);
}
super.afterInvocation(token, null);
} }
InterceptorStatusToken token = super.beforeInvocation(filterInvocation);
try {
filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse());
}
finally {
super.finallyInvocation(token);
}
super.afterInvocation(token, null);
}
private boolean isApplied(FilterInvocation filterInvocation) {
return (filterInvocation.getRequest() != null)
&& (filterInvocation.getRequest().getAttribute(FILTER_APPLIED) != null);
} }
/** /**

View File

@ -77,7 +77,6 @@ public class RequestKey {
} }
sb.append(this.url); sb.append(this.url);
sb.append("]"); sb.append("]");
return sb.toString(); return sb.toString();
} }

View File

@ -30,6 +30,7 @@ import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.MessageSource; import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware; import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.InternalAuthenticationServiceException; import org.springframework.security.authentication.InternalAuthenticationServiceException;
@ -206,52 +207,39 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
* </ol> * </ol>
*/ */
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
HttpServletRequest request = (HttpServletRequest) req; private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
HttpServletResponse response = (HttpServletResponse) res; throws IOException, ServletException {
if (!requiresAuthentication(request, response)) { if (!requiresAuthentication(request, response)) {
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
this.logger.debug("Request is to process authentication");
if (this.logger.isDebugEnabled()) {
this.logger.debug("Request is to process authentication");
}
Authentication authResult;
try { try {
authResult = attemptAuthentication(request, response); Authentication authenticationResult = attemptAuthentication(request, response);
if (authResult == null) { if (authenticationResult == null) {
// return immediately as subclass has indicated that it hasn't completed // return immediately as subclass has indicated that it hasn't completed
// authentication
return; return;
} }
this.sessionStrategy.onAuthentication(authResult, request, response); this.sessionStrategy.onAuthentication(authenticationResult, request, response);
// Authentication success
if (this.continueChainBeforeSuccessfulAuthentication) {
chain.doFilter(request, response);
}
successfulAuthentication(request, response, chain, authenticationResult);
} }
catch (InternalAuthenticationServiceException failed) { catch (InternalAuthenticationServiceException failed) {
this.logger.error("An internal error occurred while trying to authenticate the user.", failed); this.logger.error("An internal error occurred while trying to authenticate the user.", failed);
unsuccessfulAuthentication(request, response, failed); unsuccessfulAuthentication(request, response, failed);
return;
} }
catch (AuthenticationException failed) { catch (AuthenticationException ex) {
// Authentication failed // Authentication failed
unsuccessfulAuthentication(request, response, failed); unsuccessfulAuthentication(request, response, ex);
return;
} }
// Authentication success
if (this.continueChainBeforeSuccessfulAuthentication) {
chain.doFilter(request, response);
}
successfulAuthentication(request, response, chain, authResult);
} }
/** /**
@ -316,20 +304,13 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
*/ */
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
Authentication authResult) throws IOException, ServletException { Authentication authResult) throws IOException, ServletException {
this.logger.debug(
if (this.logger.isDebugEnabled()) { LogMessage.format("Authentication success. Updating SecurityContextHolder to contain: %s", authResult));
this.logger.debug("Authentication success. Updating SecurityContextHolder to contain: " + authResult);
}
SecurityContextHolder.getContext().setAuthentication(authResult); SecurityContextHolder.getContext().setAuthentication(authResult);
this.rememberMeServices.loginSuccess(request, response, authResult); this.rememberMeServices.loginSuccess(request, response, authResult);
// Fire event
if (this.eventPublisher != null) { if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
} }
this.successHandler.onAuthenticationSuccess(request, response, authResult); this.successHandler.onAuthenticationSuccess(request, response, authResult);
} }
@ -347,15 +328,12 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException { AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug("Authentication request failed: " + failed.toString(), failed); this.logger.debug("Authentication request failed: " + failed.toString(), failed);
this.logger.debug("Updated SecurityContextHolder to contain null Authentication"); this.logger.debug("Updated SecurityContextHolder to contain null Authentication");
this.logger.debug("Delegating to authentication failure handler " + this.failureHandler); this.logger.debug("Delegating to authentication failure handler " + this.failureHandler);
} }
this.rememberMeServices.loginFail(request, response); this.rememberMeServices.loginFail(request, response);
this.failureHandler.onAuthenticationFailure(request, response, failed); this.failureHandler.onAuthenticationFailure(request, response, failed);
} }

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.RedirectStrategy;
@ -84,18 +85,16 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler {
protected void handle(HttpServletRequest request, HttpServletResponse response, Authentication authentication) protected void handle(HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException { throws IOException, ServletException {
String targetUrl = determineTargetUrl(request, response, authentication); String targetUrl = determineTargetUrl(request, response, authentication);
if (response.isCommitted()) { if (response.isCommitted()) {
this.logger.debug("Response has already been committed. Unable to redirect to " + targetUrl); this.logger.debug(
LogMessage.format("Response has already been committed. Unable to redirect to %s", targetUrl));
return; return;
} }
this.redirectStrategy.sendRedirect(request, response, targetUrl); this.redirectStrategy.sendRedirect(request, response, targetUrl);
} }
/** /**
* Builds the target URL according to the logic defined in the main class Javadoc * Builds the target URL according to the logic defined in the main class Javadoc
*
* @since 5.2 * @since 5.2
*/ */
protected String determineTargetUrl(HttpServletRequest request, HttpServletResponse response, protected String determineTargetUrl(HttpServletRequest request, HttpServletResponse response,
@ -110,30 +109,23 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler {
if (isAlwaysUseDefaultTargetUrl()) { if (isAlwaysUseDefaultTargetUrl()) {
return this.defaultTargetUrl; return this.defaultTargetUrl;
} }
// Check for the parameter and use that if available // Check for the parameter and use that if available
String targetUrl = null; String targetUrl = null;
if (this.targetUrlParameter != null) { if (this.targetUrlParameter != null) {
targetUrl = request.getParameter(this.targetUrlParameter); targetUrl = request.getParameter(this.targetUrlParameter);
if (StringUtils.hasText(targetUrl)) { if (StringUtils.hasText(targetUrl)) {
this.logger.debug("Found targetUrlParameter in request: " + targetUrl); this.logger.debug("Found targetUrlParameter in request: " + targetUrl);
return targetUrl; return targetUrl;
} }
} }
if (this.useReferer && !StringUtils.hasLength(targetUrl)) { if (this.useReferer && !StringUtils.hasLength(targetUrl)) {
targetUrl = request.getHeader("Referer"); targetUrl = request.getHeader("Referer");
this.logger.debug("Using Referer header: " + targetUrl); this.logger.debug("Using Referer header: " + targetUrl);
} }
if (!StringUtils.hasText(targetUrl)) { if (!StringUtils.hasText(targetUrl)) {
targetUrl = this.defaultTargetUrl; targetUrl = this.defaultTargetUrl;
this.logger.debug("Using default Url: " + targetUrl); this.logger.debug("Using default Url: " + targetUrl);
} }
return targetUrl; return targetUrl;
} }

View File

@ -26,6 +26,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -85,31 +86,24 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
if (SecurityContextHolder.getContext().getAuthentication() == null) { if (SecurityContextHolder.getContext().getAuthentication() == null) {
SecurityContextHolder.getContext().setAuthentication(createAuthentication((HttpServletRequest) req)); SecurityContextHolder.getContext().setAuthentication(createAuthentication((HttpServletRequest) req));
this.logger.debug(LogMessage.of(() -> "Populated SecurityContextHolder with anonymous token: '"
if (this.logger.isDebugEnabled()) { + SecurityContextHolder.getContext().getAuthentication() + "'"));
this.logger.debug("Populated SecurityContextHolder with anonymous token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
}
} }
else { else {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage
this.logger.debug("SecurityContextHolder not populated with anonymous token, as it already contained: '" .of(() -> "SecurityContextHolder not populated with anonymous token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"); + SecurityContextHolder.getContext().getAuthentication() + "'"));
}
} }
chain.doFilter(req, res); chain.doFilter(req, res);
} }
protected Authentication createAuthentication(HttpServletRequest request) { protected Authentication createAuthentication(HttpServletRequest request) {
AnonymousAuthenticationToken auth = new AnonymousAuthenticationToken(this.key, this.principal, AnonymousAuthenticationToken token = new AnonymousAuthenticationToken(this.key, this.principal,
this.authorities); this.authorities);
auth.setDetails(this.authenticationDetailsSource.buildDetails(request)); token.setDetails(this.authenticationDetailsSource.buildDetails(request));
return token;
return auth;
} }
public void setAuthenticationDetailsSource( public void setAuthenticationDetailsSource(

View File

@ -29,7 +29,7 @@ import org.springframework.util.Assert;
/** /**
* Adapts a {@link AuthenticationEntryPoint} into a {@link AuthenticationFailureHandler} * Adapts a {@link AuthenticationEntryPoint} into a {@link AuthenticationFailureHandler}
* *
* @author sbespalov * @author Sergey Bespalov
* @since 5.2.0 * @since 5.2.0
*/ */
public class AuthenticationEntryPointFailureHandler implements AuthenticationFailureHandler { public class AuthenticationEntryPointFailureHandler implements AuthenticationFailureHandler {

View File

@ -84,7 +84,6 @@ public class AuthenticationFilter extends OncePerRequestFilter {
AuthenticationConverter authenticationConverter) { AuthenticationConverter authenticationConverter) {
Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null");
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
this.authenticationManagerResolver = authenticationManagerResolver; this.authenticationManagerResolver = authenticationManagerResolver;
this.authenticationConverter = authenticationConverter; this.authenticationConverter = authenticationConverter;
} }
@ -142,19 +141,16 @@ public class AuthenticationFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
try { try {
Authentication authenticationResult = attemptAuthentication(request, response); Authentication authenticationResult = attemptAuthentication(request, response);
if (authenticationResult == null) { if (authenticationResult == null) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) { if (session != null) {
request.changeSessionId(); request.changeSessionId();
} }
successfulAuthentication(request, response, filterChain, authenticationResult); successfulAuthentication(request, response, filterChain, authenticationResult);
} }
catch (AuthenticationException ex) { catch (AuthenticationException ex) {
@ -182,13 +178,11 @@ public class AuthenticationFilter extends OncePerRequestFilter {
if (authentication == null) { if (authentication == null) {
return null; return null;
} }
AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request); AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request);
Authentication authenticationResult = authenticationManager.authenticate(authentication); Authentication authenticationResult = authenticationManager.authenticate(authentication);
if (authenticationResult == null) { if (authenticationResult == null) {
throw new ServletException("AuthenticationManager should not return null Authentication object."); throw new ServletException("AuthenticationManager should not return null Authentication object.");
} }
return authenticationResult; return authenticationResult;
} }

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.util.matcher.ELRequestMatcher; import org.springframework.security.web.util.matcher.ELRequestMatcher;
@ -62,7 +63,7 @@ import org.springframework.util.Assert;
*/ */
public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPoint, InitializingBean { public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPoint, InitializingBean {
private final Log logger = LogFactory.getLog(getClass()); private static final Log logger = LogFactory.getLog(DelegatingAuthenticationEntryPoint.class);
private final LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints; private final LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints;
@ -75,25 +76,16 @@ public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPo
@Override @Override
public void commence(HttpServletRequest request, HttpServletResponse response, public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException, ServletException { AuthenticationException authException) throws IOException, ServletException {
for (RequestMatcher requestMatcher : this.entryPoints.keySet()) { for (RequestMatcher requestMatcher : this.entryPoints.keySet()) {
if (this.logger.isDebugEnabled()) { logger.debug(LogMessage.format("Trying to match using %s", requestMatcher));
this.logger.debug("Trying to match using " + requestMatcher);
}
if (requestMatcher.matches(request)) { if (requestMatcher.matches(request)) {
AuthenticationEntryPoint entryPoint = this.entryPoints.get(requestMatcher); AuthenticationEntryPoint entryPoint = this.entryPoints.get(requestMatcher);
if (this.logger.isDebugEnabled()) { logger.debug(LogMessage.format("Match found! Executing %s", entryPoint));
this.logger.debug("Match found! Executing " + entryPoint);
}
entryPoint.commence(request, response, authException); entryPoint.commence(request, response, authException);
return; return;
} }
} }
logger.debug(LogMessage.format("No match found. Using default entry point %s", this.defaultEntryPoint));
if (this.logger.isDebugEnabled()) {
this.logger.debug("No match found. Using default entry point " + this.defaultEntryPoint);
}
// No EntryPoint matched, use defaultEntryPoint // No EntryPoint matched, use defaultEntryPoint
this.defaultEntryPoint.commence(request, response, authException); this.defaultEntryPoint.commence(request, response, authException);
} }

View File

@ -62,9 +62,6 @@ public class DelegatingAuthenticationFailureHandler implements AuthenticationFai
this.defaultHandler = defaultHandler; this.defaultHandler = defaultHandler;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException { AuthenticationException exception) throws IOException, ServletException {

View File

@ -49,7 +49,6 @@ public class ExceptionMappingAuthenticationFailureHandler extends SimpleUrlAuthe
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException { AuthenticationException exception) throws IOException, ServletException {
String url = this.failureUrlMap.get(exception.getClass().getName()); String url = this.failureUrlMap.get(exception.getClass().getName());
if (url != null) { if (url != null) {
getRedirectStrategy().sendRedirect(request, response, url); getRedirectStrategy().sendRedirect(request, response, url);
} }

View File

@ -55,9 +55,7 @@ public class Http403ForbiddenEntryPoint implements AuthenticationEntryPoint {
@Override @Override
public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException arg2) public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException arg2)
throws IOException { throws IOException {
if (logger.isDebugEnabled()) { logger.debug("Pre-authenticated entry point called. Rejecting access");
logger.debug("Pre-authenticated entry point called. Rejecting access");
}
response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied"); response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied");
} }

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
@ -93,9 +94,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
public void afterPropertiesSet() { public void afterPropertiesSet() {
Assert.isTrue(StringUtils.hasText(this.loginFormUrl) && UrlUtils.isValidRedirectUrl(this.loginFormUrl), Assert.isTrue(StringUtils.hasText(this.loginFormUrl) && UrlUtils.isValidRedirectUrl(this.loginFormUrl),
"loginFormUrl must be specified and must be a valid redirect URL"); "loginFormUrl must be specified and must be a valid redirect URL");
if (this.useForward && UrlUtils.isAbsoluteUrl(this.loginFormUrl)) { Assert.isTrue(!this.useForward || !UrlUtils.isAbsoluteUrl(this.loginFormUrl),
throw new IllegalArgumentException("useForward must be false if using an absolute loginFormURL"); "useForward must be false if using an absolute loginFormURL");
}
Assert.notNull(this.portMapper, "portMapper must be specified"); Assert.notNull(this.portMapper, "portMapper must be specified");
Assert.notNull(this.portResolver, "portResolver must be specified"); Assert.notNull(this.portResolver, "portResolver must be specified");
} }
@ -110,7 +110,6 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
*/ */
protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) { AuthenticationException exception) {
return getLoginFormUrl(); return getLoginFormUrl();
} }
@ -120,75 +119,55 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
@Override @Override
public void commence(HttpServletRequest request, HttpServletResponse response, public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException, ServletException { AuthenticationException authException) throws IOException, ServletException {
if (!this.useForward) {
String redirectUrl = null;
if (this.useForward) {
if (this.forceHttps && "http".equals(request.getScheme())) {
// First redirect the current request to HTTPS.
// When that request is received, the forward to the login page will be
// used.
redirectUrl = buildHttpsRedirectUrlForRequest(request);
}
if (redirectUrl == null) {
String loginForm = determineUrlToUseForThisRequest(request, response, authException);
if (logger.isDebugEnabled()) {
logger.debug("Server side forward to: " + loginForm);
}
RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm);
dispatcher.forward(request, response);
return;
}
}
else {
// redirect to login page. Use https if forceHttps true // redirect to login page. Use https if forceHttps true
String redirectUrl = buildRedirectUrlToLoginPage(request, response, authException);
redirectUrl = buildRedirectUrlToLoginPage(request, response, authException); this.redirectStrategy.sendRedirect(request, response, redirectUrl);
return;
} }
String redirectUrl = null;
this.redirectStrategy.sendRedirect(request, response, redirectUrl); if (this.forceHttps && "http".equals(request.getScheme())) {
// First redirect the current request to HTTPS. When that request is received,
// the forward to the login page will be used.
redirectUrl = buildHttpsRedirectUrlForRequest(request);
}
if (redirectUrl != null) {
this.redirectStrategy.sendRedirect(request, response, redirectUrl);
return;
}
String loginForm = determineUrlToUseForThisRequest(request, response, authException);
logger.debug(LogMessage.format("Server side forward to: %s", loginForm));
RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm);
dispatcher.forward(request, response);
return;
} }
protected String buildRedirectUrlToLoginPage(HttpServletRequest request, HttpServletResponse response, protected String buildRedirectUrlToLoginPage(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) { AuthenticationException authException) {
String loginForm = determineUrlToUseForThisRequest(request, response, authException); String loginForm = determineUrlToUseForThisRequest(request, response, authException);
if (UrlUtils.isAbsoluteUrl(loginForm)) { if (UrlUtils.isAbsoluteUrl(loginForm)) {
return loginForm; return loginForm;
} }
int serverPort = this.portResolver.getServerPort(request); int serverPort = this.portResolver.getServerPort(request);
String scheme = request.getScheme(); String scheme = request.getScheme();
RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder(); RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
urlBuilder.setScheme(scheme); urlBuilder.setScheme(scheme);
urlBuilder.setServerName(request.getServerName()); urlBuilder.setServerName(request.getServerName());
urlBuilder.setPort(serverPort); urlBuilder.setPort(serverPort);
urlBuilder.setContextPath(request.getContextPath()); urlBuilder.setContextPath(request.getContextPath());
urlBuilder.setPathInfo(loginForm); urlBuilder.setPathInfo(loginForm);
if (this.forceHttps && "http".equals(scheme)) { if (this.forceHttps && "http".equals(scheme)) {
Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort); Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort);
if (httpsPort != null) { if (httpsPort != null) {
// Overwrite scheme and port in the redirect URL // Overwrite scheme and port in the redirect URL
urlBuilder.setScheme("https"); urlBuilder.setScheme("https");
urlBuilder.setPort(httpsPort); urlBuilder.setPort(httpsPort);
} }
else { else {
logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort); logger.warn(LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s",
serverPort));
} }
} }
return urlBuilder.getUrl(); return urlBuilder.getUrl();
} }
@ -197,10 +176,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
* current request to HTTPS, before doing a forward to the login page. * current request to HTTPS, before doing a forward to the login page.
*/ */
protected String buildHttpsRedirectUrlForRequest(HttpServletRequest request) throws IOException, ServletException { protected String buildHttpsRedirectUrlForRequest(HttpServletRequest request) throws IOException, ServletException {
int serverPort = this.portResolver.getServerPort(request); int serverPort = this.portResolver.getServerPort(request);
Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort); Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort);
if (httpsPort != null) { if (httpsPort != null) {
RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder(); RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
urlBuilder.setScheme("https"); urlBuilder.setScheme("https");
@ -210,13 +187,11 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
urlBuilder.setServletPath(request.getServletPath()); urlBuilder.setServletPath(request.getServletPath());
urlBuilder.setPathInfo(request.getPathInfo()); urlBuilder.setPathInfo(request.getPathInfo());
urlBuilder.setQuery(request.getQueryString()); urlBuilder.setQuery(request.getQueryString());
return urlBuilder.getUrl(); return urlBuilder.getUrl();
} }
// Fall through to server-side forward with warning message // Fall through to server-side forward with warning message
logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort); logger.warn(
LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s", serverPort));
return null; return null;
} }

View File

@ -74,10 +74,8 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws ServletException, IOException { Authentication authentication) throws ServletException, IOException {
SavedRequest savedRequest = this.requestCache.getRequest(request, response); SavedRequest savedRequest = this.requestCache.getRequest(request, response);
if (savedRequest == null) { if (savedRequest == null) {
super.onAuthenticationSuccess(request, response, authentication); super.onAuthenticationSuccess(request, response, authentication);
return; return;
} }
String targetUrlParameter = getTargetUrlParameter(); String targetUrlParameter = getTargetUrlParameter();
@ -85,12 +83,9 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
|| (targetUrlParameter != null && StringUtils.hasText(request.getParameter(targetUrlParameter)))) { || (targetUrlParameter != null && StringUtils.hasText(request.getParameter(targetUrlParameter)))) {
this.requestCache.removeRequest(request, response); this.requestCache.removeRequest(request, response);
super.onAuthenticationSuccess(request, response, authentication); super.onAuthenticationSuccess(request, response, authentication);
return; return;
} }
clearAuthenticationAttributes(request); clearAuthenticationAttributes(request);
// Use the DefaultSavedRequest URL // Use the DefaultSavedRequest URL
String targetUrl = savedRequest.getRedirectUrl(); String targetUrl = savedRequest.getRedirectUrl();
this.logger.debug("Redirecting to DefaultSavedRequest Url: " + targetUrl); this.logger.debug("Redirecting to DefaultSavedRequest Url: " + targetUrl);

View File

@ -76,24 +76,19 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail
@Override @Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException { AuthenticationException exception) throws IOException, ServletException {
if (this.defaultFailureUrl == null) { if (this.defaultFailureUrl == null) {
this.logger.debug("No failure URL set, sending 401 Unauthorized error"); this.logger.debug("No failure URL set, sending 401 Unauthorized error");
response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase()); response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
return;
}
saveException(request, exception);
if (this.forwardToDestination) {
this.logger.debug("Forwarding to " + this.defaultFailureUrl);
request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response);
} }
else { else {
saveException(request, exception); this.logger.debug("Redirecting to " + this.defaultFailureUrl);
this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl);
if (this.forwardToDestination) {
this.logger.debug("Forwarding to " + this.defaultFailureUrl);
request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response);
}
else {
this.logger.debug("Redirecting to " + this.defaultFailureUrl);
this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl);
}
} }
} }
@ -108,13 +103,11 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail
protected final void saveException(HttpServletRequest request, AuthenticationException exception) { protected final void saveException(HttpServletRequest request, AuthenticationException exception) {
if (this.forwardToDestination) { if (this.forwardToDestination) {
request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception); request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
return;
} }
else { HttpSession session = request.getSession(false);
HttpSession session = request.getSession(false); if (session != null || this.allowSessionCreation) {
request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
if (session != null || this.allowSessionCreation) {
request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
}
} }
} }

View File

@ -59,7 +59,6 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio
@Override @Override
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException, ServletException { Authentication authentication) throws IOException, ServletException {
handle(request, response, authentication); handle(request, response, authentication);
clearAuthenticationAttributes(request); clearAuthenticationAttributes(request);
} }
@ -70,12 +69,9 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio
*/ */
protected final void clearAuthenticationAttributes(HttpServletRequest request) { protected final void clearAuthenticationAttributes(HttpServletRequest request) {
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) {
if (session == null) { session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
return;
} }
session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
} }
} }

View File

@ -74,25 +74,14 @@ public class UsernamePasswordAuthenticationFilter extends AbstractAuthentication
if (this.postOnly && !request.getMethod().equals("POST")) { if (this.postOnly && !request.getMethod().equals("POST")) {
throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod()); throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod());
} }
String username = obtainUsername(request); String username = obtainUsername(request);
String password = obtainPassword(request); username = (username != null) ? username : "";
if (username == null) {
username = "";
}
if (password == null) {
password = "";
}
username = username.trim(); username = username.trim();
String password = obtainPassword(request);
password = (password != null) ? password : "";
UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken(username, password); UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken(username, password);
// Allow subclasses to set the "details" property // Allow subclasses to set the "details" property
setDetails(request, authRequest); setDetails(request, authRequest);
return this.getAuthenticationManager().authenticate(authRequest); return this.getAuthenticationManager().authenticate(authRequest);
} }

View File

@ -44,7 +44,6 @@ public class WebAuthenticationDetails implements Serializable {
*/ */
public WebAuthenticationDetails(HttpServletRequest request) { public WebAuthenticationDetails(HttpServletRequest request) {
this.remoteAddress = request.getRemoteAddr(); this.remoteAddress = request.getRemoteAddr();
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
this.sessionId = (session != null) ? session.getId() : null; this.sessionId = (session != null) ? session.getId() : null;
} }
@ -62,39 +61,31 @@ public class WebAuthenticationDetails implements Serializable {
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
if (obj instanceof WebAuthenticationDetails) { if (obj instanceof WebAuthenticationDetails) {
WebAuthenticationDetails rhs = (WebAuthenticationDetails) obj; WebAuthenticationDetails other = (WebAuthenticationDetails) obj;
if ((this.remoteAddress == null) && (other.getRemoteAddress() != null)) {
if ((this.remoteAddress == null) && (rhs.getRemoteAddress() != null)) {
return false; return false;
} }
if ((this.remoteAddress != null) && (other.getRemoteAddress() == null)) {
if ((this.remoteAddress != null) && (rhs.getRemoteAddress() == null)) {
return false; return false;
} }
if (this.remoteAddress != null) { if (this.remoteAddress != null) {
if (!this.remoteAddress.equals(rhs.getRemoteAddress())) { if (!this.remoteAddress.equals(other.getRemoteAddress())) {
return false; return false;
} }
} }
if ((this.sessionId == null) && (other.getSessionId() != null)) {
if ((this.sessionId == null) && (rhs.getSessionId() != null)) {
return false; return false;
} }
if ((this.sessionId != null) && (other.getSessionId() == null)) {
if ((this.sessionId != null) && (rhs.getSessionId() == null)) {
return false; return false;
} }
if (this.sessionId != null) { if (this.sessionId != null) {
if (!this.sessionId.equals(rhs.getSessionId())) { if (!this.sessionId.equals(other.getSessionId())) {
return false; return false;
} }
} }
return true; return true;
} }
return false; return false;
} }
@ -118,15 +109,12 @@ public class WebAuthenticationDetails implements Serializable {
@Override @Override
public int hashCode() { public int hashCode() {
int code = 7654; int code = 7654;
if (this.remoteAddress != null) { if (this.remoteAddress != null) {
code = code * (this.remoteAddress.hashCode() % 7); code = code * (this.remoteAddress.hashCode() % 7);
} }
if (this.sessionId != null) { if (this.sessionId != null) {
code = code * (this.sessionId.hashCode() % 7); code = code * (this.sessionId.hashCode() % 7);
} }
return code; return code;
} }
@ -136,7 +124,6 @@ public class WebAuthenticationDetails implements Serializable {
sb.append(super.toString()).append(": "); sb.append(super.toString()).append(": ");
sb.append("RemoteIpAddress: ").append(this.getRemoteAddress()).append("; "); sb.append("RemoteIpAddress: ").append(this.getRemoteAddress()).append("; ");
sb.append("SessionId: ").append(this.getSessionId()); sb.append("SessionId: ").append(this.getSessionId());
return sb.toString(); return sb.toString();
} }

View File

@ -43,15 +43,14 @@ public final class CookieClearingLogoutHandler implements LogoutHandler {
Assert.notNull(cookiesToClear, "List of cookies cannot be null"); Assert.notNull(cookiesToClear, "List of cookies cannot be null");
List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>(); List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>();
for (String cookieName : cookiesToClear) { for (String cookieName : cookiesToClear) {
Function<HttpServletRequest, Cookie> f = (request) -> { cookieList.add((request) -> {
Cookie cookie = new Cookie(cookieName, null); Cookie cookie = new Cookie(cookieName, null);
String cookiePath = request.getContextPath() + "/"; String cookiePath = request.getContextPath() + "/";
cookie.setPath(cookiePath); cookie.setPath(cookiePath);
cookie.setMaxAge(0); cookie.setMaxAge(0);
cookie.setSecure(request.isSecure()); cookie.setSecure(request.isSecure());
return cookie; return cookie;
}; });
cookieList.add(f);
} }
this.cookiesToClear = cookieList; this.cookiesToClear = cookieList;
} }
@ -65,8 +64,7 @@ public final class CookieClearingLogoutHandler implements LogoutHandler {
List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>(); List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>();
for (Cookie cookie : cookiesToClear) { for (Cookie cookie : cookiesToClear) {
Assert.isTrue(cookie.getMaxAge() == 0, "Cookie maxAge must be 0"); Assert.isTrue(cookie.getMaxAge() == 0, "Cookie maxAge must be 0");
Function<HttpServletRequest, Cookie> f = (request) -> cookie; cookieList.add((request) -> cookie);
cookieList.add(f);
} }
this.cookiesToClear = cookieList; this.cookiesToClear = cookieList;
} }

View File

@ -25,6 +25,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
@ -83,25 +84,20 @@ public class LogoutFilter extends GenericFilterBean {
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (requiresLogout(request, response)) { if (requiresLogout(request, response)) {
Authentication auth = SecurityContextHolder.getContext().getAuthentication(); Authentication auth = SecurityContextHolder.getContext().getAuthentication();
this.logger.debug(LogMessage.format("Logging out user '%s' and transferring to logout destination", auth));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Logging out user '" + auth + "' and transferring to logout destination");
}
this.handler.logout(request, response, auth); this.handler.logout(request, response, auth);
this.logoutSuccessHandler.onLogoutSuccess(request, response, auth); this.logoutSuccessHandler.onLogoutSuccess(request, response, auth);
return; return;
} }
chain.doFilter(request, response); chain.doFilter(request, response);
} }

View File

@ -23,6 +23,7 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
@ -61,16 +62,14 @@ public class SecurityContextLogoutHandler implements LogoutHandler {
if (this.invalidateHttpSession) { if (this.invalidateHttpSession) {
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) { if (session != null) {
this.logger.debug("Invalidating session: " + session.getId()); this.logger.debug(LogMessage.format("Invalidating session: %s", session.getId()));
session.invalidate(); session.invalidate();
} }
} }
if (this.clearAuthentication) { if (this.clearAuthentication) {
SecurityContext context = SecurityContextHolder.getContext(); SecurityContext context = SecurityContextHolder.getContext();
context.setAuthentication(null); context.setAuthentication(null);
} }
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
} }

View File

@ -28,6 +28,7 @@ import javax.servlet.http.HttpSession;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent; import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent;
@ -124,16 +125,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
@Override @Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
this.logger.debug(LogMessage
if (this.logger.isDebugEnabled()) { .of(() -> "Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication()));
this.logger
.debug("Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication());
}
if (this.requiresAuthenticationRequestMatcher.matches((HttpServletRequest) request)) { if (this.requiresAuthenticationRequestMatcher.matches((HttpServletRequest) request)) {
doAuthenticate((HttpServletRequest) request, (HttpServletResponse) response); doAuthenticate((HttpServletRequest) request, (HttpServletResponse) response);
} }
chain.doFilter(request, response); chain.doFilter(request, response);
} }
@ -156,21 +152,15 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
* @return true if the principal has changed, else false * @return true if the principal has changed, else false
*/ */
protected boolean principalChanged(HttpServletRequest request, Authentication currentAuthentication) { protected boolean principalChanged(HttpServletRequest request, Authentication currentAuthentication) {
Object principal = getPreAuthenticatedPrincipal(request); Object principal = getPreAuthenticatedPrincipal(request);
if ((principal instanceof String) && currentAuthentication.getName().equals(principal)) { if ((principal instanceof String) && currentAuthentication.getName().equals(principal)) {
return false; return false;
} }
if (principal != null && principal.equals(currentAuthentication.getPrincipal())) { if (principal != null && principal.equals(currentAuthentication.getPrincipal())) {
return false; return false;
} }
this.logger.debug(LogMessage.format("Pre-authenticated principal has changed to %s and will be reauthenticated",
if (this.logger.isDebugEnabled()) { principal));
this.logger
.debug("Pre-authenticated principal has changed to " + principal + " and will be reauthenticated");
}
return true; return true;
} }
@ -179,35 +169,24 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
*/ */
private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) private void doAuthenticate(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException { throws IOException, ServletException {
Authentication authResult;
Object principal = getPreAuthenticatedPrincipal(request); Object principal = getPreAuthenticatedPrincipal(request);
Object credentials = getPreAuthenticatedCredentials(request);
if (principal == null) { if (principal == null) {
if (this.logger.isDebugEnabled()) { this.logger.debug("No pre-authenticated principal found in request");
this.logger.debug("No pre-authenticated principal found in request");
}
return; return;
} }
this.logger.debug(LogMessage.format("preAuthenticatedPrincipal = %s, trying to authenticate", principal));
if (this.logger.isDebugEnabled()) { Object credentials = getPreAuthenticatedCredentials(request);
this.logger.debug("preAuthenticatedPrincipal = " + principal + ", trying to authenticate");
}
try { try {
PreAuthenticatedAuthenticationToken authRequest = new PreAuthenticatedAuthenticationToken(principal, PreAuthenticatedAuthenticationToken authenticationRequest = new PreAuthenticatedAuthenticationToken(
credentials); principal, credentials);
authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
authResult = this.authenticationManager.authenticate(authRequest); Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
successfulAuthentication(request, response, authResult); successfulAuthentication(request, response, authenticationResult);
} }
catch (AuthenticationException failed) { catch (AuthenticationException ex) {
unsuccessfulAuthentication(request, response, failed); unsuccessfulAuthentication(request, response, ex);
if (!this.continueFilterChainOnUnsuccessfulAuthentication) { if (!this.continueFilterChainOnUnsuccessfulAuthentication) {
throw failed; throw ex;
} }
} }
} }
@ -218,15 +197,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
*/ */
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response,
Authentication authResult) throws IOException, ServletException { Authentication authResult) throws IOException, ServletException {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Authentication success: %s", authResult));
this.logger.debug("Authentication success: " + authResult);
}
SecurityContextHolder.getContext().setAuthentication(authResult); SecurityContextHolder.getContext().setAuthentication(authResult);
// Fire event
if (this.eventPublisher != null) { if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
} }
if (this.authenticationSuccessHandler != null) { if (this.authenticationSuccessHandler != null) {
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult); this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult);
} }
@ -241,12 +216,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException { AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
this.logger.debug("Cleared security context due to exception", failed);
if (this.logger.isDebugEnabled()) {
this.logger.debug("Cleared security context due to exception", failed);
}
request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed); request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed);
if (this.authenticationFailureHandler != null) { if (this.authenticationFailureHandler != null) {
this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
} }
@ -355,36 +326,27 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
@Override @Override
public boolean matches(HttpServletRequest request) { public boolean matches(HttpServletRequest request) {
Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); Authentication currentUser = SecurityContextHolder.getContext().getAuthentication();
if (currentUser == null) { if (currentUser == null) {
return true; return true;
} }
if (!AbstractPreAuthenticatedProcessingFilter.this.checkForPrincipalChanges) { if (!AbstractPreAuthenticatedProcessingFilter.this.checkForPrincipalChanges) {
return false; return false;
} }
if (!principalChanged(request, currentUser)) { if (!principalChanged(request, currentUser)) {
return false; return false;
} }
AbstractPreAuthenticatedProcessingFilter.this.logger AbstractPreAuthenticatedProcessingFilter.this.logger
.debug("Pre-authenticated principal has changed and will be reauthenticated"); .debug("Pre-authenticated principal has changed and will be reauthenticated");
if (AbstractPreAuthenticatedProcessingFilter.this.invalidateSessionOnPrincipalChange) { if (AbstractPreAuthenticatedProcessingFilter.this.invalidateSessionOnPrincipalChange) {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) { if (session != null) {
AbstractPreAuthenticatedProcessingFilter.this.logger.debug("Invalidating existing session"); AbstractPreAuthenticatedProcessingFilter.this.logger.debug("Invalidating existing session");
session.invalidate(); session.invalidate();
request.getSession(); request.getSession();
} }
} }
return true; return true;
} }

View File

@ -21,6 +21,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
@ -50,11 +51,11 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro
private static final Log logger = LogFactory.getLog(PreAuthenticatedAuthenticationProvider.class); private static final Log logger = LogFactory.getLog(PreAuthenticatedAuthenticationProvider.class);
private AuthenticationUserDetailsService<PreAuthenticatedAuthenticationToken> preAuthenticatedUserDetailsService = null; private AuthenticationUserDetailsService<PreAuthenticatedAuthenticationToken> preAuthenticatedUserDetailsService;
private UserDetailsChecker userDetailsChecker = new AccountStatusUserDetailsChecker(); private UserDetailsChecker userDetailsChecker = new AccountStatusUserDetailsChecker();
private boolean throwExceptionWhenTokenRejected = false; private boolean throwExceptionWhenTokenRejected;
private int order = -1; // default: same as non-ordered private int order = -1; // default: same as non-ordered
@ -77,38 +78,27 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro
if (!supports(authentication.getClass())) { if (!supports(authentication.getClass())) {
return null; return null;
} }
logger.debug(LogMessage.format("PreAuthenticated authentication request: %s", authentication));
if (logger.isDebugEnabled()) {
logger.debug("PreAuthenticated authentication request: " + authentication);
}
if (authentication.getPrincipal() == null) { if (authentication.getPrincipal() == null) {
logger.debug("No pre-authenticated principal found in request."); logger.debug("No pre-authenticated principal found in request.");
if (this.throwExceptionWhenTokenRejected) { if (this.throwExceptionWhenTokenRejected) {
throw new BadCredentialsException("No pre-authenticated principal found in request."); throw new BadCredentialsException("No pre-authenticated principal found in request.");
} }
return null; return null;
} }
if (authentication.getCredentials() == null) { if (authentication.getCredentials() == null) {
logger.debug("No pre-authenticated credentials found in request."); logger.debug("No pre-authenticated credentials found in request.");
if (this.throwExceptionWhenTokenRejected) { if (this.throwExceptionWhenTokenRejected) {
throw new BadCredentialsException("No pre-authenticated credentials found in request."); throw new BadCredentialsException("No pre-authenticated credentials found in request.");
} }
return null; return null;
} }
UserDetails userDetails = this.preAuthenticatedUserDetailsService
UserDetails ud = this.preAuthenticatedUserDetailsService
.loadUserDetails((PreAuthenticatedAuthenticationToken) authentication); .loadUserDetails((PreAuthenticatedAuthenticationToken) authentication);
this.userDetailsChecker.check(userDetails);
this.userDetailsChecker.check(ud); PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(userDetails,
authentication.getCredentials(), userDetails.getAuthorities());
PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(ud,
authentication.getCredentials(), ud.getAuthorities());
result.setDetails(authentication.getDetails()); result.setDetails(authentication.getDetails());
return result; return result;
} }

View File

@ -46,7 +46,6 @@ public class PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails extends
public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(HttpServletRequest request, public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(HttpServletRequest request,
Collection<? extends GrantedAuthority> authorities) { Collection<? extends GrantedAuthority> authorities) {
super(request); super(request);
List<GrantedAuthority> temp = new ArrayList<>(authorities.size()); List<GrantedAuthority> temp = new ArrayList<>(authorities.size());
temp.addAll(authorities); temp.addAll(authorities);
this.authorities = Collections.unmodifiableList(temp); this.authorities = Collections.unmodifiableList(temp);

View File

@ -59,12 +59,10 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat
@Override @Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
String principal = (String) request.getAttribute(this.principalEnvironmentVariable); String principal = (String) request.getAttribute(this.principalEnvironmentVariable);
if (principal == null && this.exceptionIfVariableMissing) { if (principal == null && this.exceptionIfVariableMissing) {
throw new PreAuthenticatedCredentialsNotFoundException( throw new PreAuthenticatedCredentialsNotFoundException(
this.principalEnvironmentVariable + " variable not found in request."); this.principalEnvironmentVariable + " variable not found in request.");
} }
return principal; return principal;
} }
@ -78,7 +76,6 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat
if (this.credentialsEnvironmentVariable != null) { if (this.credentialsEnvironmentVariable != null) {
return request.getAttribute(this.credentialsEnvironmentVariable); return request.getAttribute(this.credentialsEnvironmentVariable);
} }
return "N/A"; return "N/A";
} }

View File

@ -60,12 +60,10 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP
@Override @Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
String principal = request.getHeader(this.principalRequestHeader); String principal = request.getHeader(this.principalRequestHeader);
if (principal == null && this.exceptionIfHeaderMissing) { if (principal == null && this.exceptionIfHeaderMissing) {
throw new PreAuthenticatedCredentialsNotFoundException( throw new PreAuthenticatedCredentialsNotFoundException(
this.principalRequestHeader + " header not found in request."); this.principalRequestHeader + " header not found in request.");
} }
return principal; return principal;
} }
@ -79,7 +77,6 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP
if (this.credentialsRequestHeader != null) { if (this.credentialsRequestHeader != null) {
return request.getHeader(this.credentialsRequestHeader); return request.getHeader(this.credentialsRequestHeader);
} }
return "N/A"; return "N/A";
} }

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper;
@ -76,13 +77,11 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements
*/ */
protected Collection<String> getUserRoles(HttpServletRequest request) { protected Collection<String> getUserRoles(HttpServletRequest request) {
ArrayList<String> j2eeUserRolesList = new ArrayList<>(); ArrayList<String> j2eeUserRolesList = new ArrayList<>();
for (String role : this.j2eeMappableRoles) { for (String role : this.j2eeMappableRoles) {
if (request.isUserInRole(role)) { if (request.isUserInRole(role)) {
j2eeUserRolesList.add(role); j2eeUserRolesList.add(role);
} }
} }
return j2eeUserRolesList; return j2eeUserRolesList;
} }
@ -93,19 +92,14 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements
*/ */
@Override @Override
public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails buildDetails(HttpServletRequest context) { public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails buildDetails(HttpServletRequest context) {
Collection<String> j2eeUserRoles = getUserRoles(context); Collection<String> j2eeUserRoles = getUserRoles(context);
Collection<? extends GrantedAuthority> userGas = this.j2eeUserRoles2GrantedAuthoritiesMapper Collection<? extends GrantedAuthority> userGrantedAuthorities = this.j2eeUserRoles2GrantedAuthoritiesMapper
.getGrantedAuthorities(j2eeUserRoles); .getGrantedAuthorities(j2eeUserRoles);
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug("J2EE roles [" + j2eeUserRoles + "] mapped to Granted Authorities: [" + userGas + "]"); this.logger.debug(LogMessage.format("J2EE roles [%s] mapped to Granted Authorities: [%s]", j2eeUserRoles,
userGrantedAuthorities));
} }
return new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(context, userGrantedAuthorities);
PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails result = new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(
context, userGas);
return result;
} }
/** /**

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.j2ee;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/** /**
@ -36,9 +37,7 @@ public class J2eePreAuthenticatedProcessingFilter extends AbstractPreAuthenticat
@Override @Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) { protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) {
Object principal = (httpRequest.getUserPrincipal() != null) ? httpRequest.getUserPrincipal().getName() : null; Object principal = (httpRequest.getUserPrincipal() != null) ? httpRequest.getUserPrincipal().getName() : null;
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("PreAuthenticated J2EE principal: %s", principal));
this.logger.debug("PreAuthenticated J2EE principal: " + principal);
}
return principal; return principal;
} }

View File

@ -22,6 +22,7 @@ import java.io.StringReader;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilder;
@ -43,6 +44,7 @@ import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader; import org.springframework.core.io.ResourceLoader;
import org.springframework.security.core.authority.mapping.MappableAttributesRetriever; import org.springframework.security.core.authority.mapping.MappableAttributesRetriever;
import org.springframework.util.Assert;
/** /**
* This <tt>MappableAttributesRetriever</tt> implementation reads the list of defined J2EE * This <tt>MappableAttributesRetriever</tt> implementation reads the list of defined J2EE
@ -82,17 +84,17 @@ public class WebXmlMappableAttributesRetriever
Resource webXml = this.resourceLoader.getResource("/WEB-INF/web.xml"); Resource webXml = this.resourceLoader.getResource("/WEB-INF/web.xml");
Document doc = getDocument(webXml.getInputStream()); Document doc = getDocument(webXml.getInputStream());
NodeList webApp = doc.getElementsByTagName("web-app"); NodeList webApp = doc.getElementsByTagName("web-app");
if (webApp.getLength() != 1) { Assert.isTrue(webApp.getLength() == 1, () -> "Failed to find 'web-app' element in resource" + webXml);
throw new IllegalArgumentException("Failed to find 'web-app' element in resource" + webXml);
}
NodeList securityRoles = ((Element) webApp.item(0)).getElementsByTagName("security-role"); NodeList securityRoles = ((Element) webApp.item(0)).getElementsByTagName("security-role");
List<String> roleNames = getRoleNames(webXml, securityRoles);
this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames));
}
private List<String> getRoleNames(Resource webXml, NodeList securityRoles) {
ArrayList<String> roleNames = new ArrayList<>(); ArrayList<String> roleNames = new ArrayList<>();
for (int i = 0; i < securityRoles.getLength(); i++) { for (int i = 0; i < securityRoles.getLength(); i++) {
Element secRoleElt = (Element) securityRoles.item(i); Element securityRoleElement = (Element) securityRoles.item(i);
NodeList roles = secRoleElt.getElementsByTagName("role-name"); NodeList roles = securityRoleElement.getElementsByTagName("role-name");
if (roles.getLength() > 0) { if (roles.getLength() > 0) {
String roleName = roles.item(0).getTextContent().trim(); String roleName = roles.item(0).getTextContent().trim();
roleNames.add(roleName); roleNames.add(roleName);
@ -102,22 +104,19 @@ public class WebXmlMappableAttributesRetriever
this.logger.info("No security-role elements found in " + webXml); this.logger.info("No security-role elements found in " + webXml);
} }
} }
return roleNames;
this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames));
} }
/** /**
* @return Document for the specified InputStream * @return Document for the specified InputStream
*/ */
private Document getDocument(InputStream aStream) { private Document getDocument(InputStream aStream) {
Document doc;
try { try {
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
factory.setValidating(false); factory.setValidating(false);
DocumentBuilder db = factory.newDocumentBuilder(); DocumentBuilder builder = factory.newDocumentBuilder();
db.setEntityResolver(new MyEntityResolver()); builder.setEntityResolver(new MyEntityResolver());
doc = db.parse(aStream); return builder.parse(aStream);
return doc;
} }
catch (FactoryConfigurationError | IOException | SAXException | ParserConfigurationException ex) { catch (FactoryConfigurationError | IOException | SAXException | ParserConfigurationException ex) {
throw new RuntimeException("Unable to parse document object", ex); throw new RuntimeException("Unable to parse document object", ex);

View File

@ -31,6 +31,8 @@ import javax.security.auth.Subject;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
/** /**
* WebSphere Security helper class to allow retrieval of the current username and groups. * WebSphere Security helper class to allow retrieval of the current username and groups.
* <p> * <p>
@ -75,9 +77,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
* @return String the security name for the given subject * @return String the security name for the given subject
*/ */
private static String getSecurityName(final Subject subject) { private static String getSecurityName(final Subject subject) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Determining Websphere security name for subject %s", subject));
logger.debug("Determining Websphere security name for subject " + subject);
}
String userSecurityName = null; String userSecurityName = null;
if (subject != null) { if (subject != null) {
// SEC-803 // SEC-803
@ -86,9 +86,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
userSecurityName = (String) invokeMethod(getSecurityNameMethod(), credential); userSecurityName = (String) invokeMethod(getSecurityNameMethod(), credential);
} }
} }
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Websphere security name is %s for subject %s", subject, userSecurityName));
logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject);
}
return userSecurityName; return userSecurityName;
} }
@ -119,69 +117,56 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static List<String> getWebSphereGroups(final String securityName) { private static List<String> getWebSphereGroups(final String securityName) {
Context ic = null; Context context = null;
try { try {
// TODO: Cache UserRegistry object // TODO: Cache UserRegistry object
ic = new InitialContext(); context = new InitialContext();
Object objRef = ic.lookup(USER_REGISTRY); Object objRef = context.lookup(USER_REGISTRY);
Object userReg = invokeMethod(getNarrowMethod(), null, objRef, Object userReg = invokeMethod(getNarrowMethod(), null, objRef,
Class.forName("com.ibm.websphere.security.UserRegistry")); Class.forName("com.ibm.websphere.security.UserRegistry"));
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Determining WebSphere groups for user %s using WebSphere UserRegistry %s",
logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry " securityName, userReg));
+ userReg); final Collection<String> groups = (Collection<String>) invokeMethod(getGroupsForUserMethod(), userReg,
}
final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg,
new Object[] { securityName }); new Object[] { securityName });
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Groups for user %s: %s", securityName, groups));
logger.debug("Groups for user " + securityName + ": " + groups.toString()); return new ArrayList<String>(groups);
}
return new ArrayList(groups);
} }
catch (Exception ex) { catch (Exception ex) {
logger.error("Exception occured while looking up groups for user", ex); logger.error("Exception occured while looking up groups for user", ex);
throw new RuntimeException("Exception occured while looking up groups for user", ex); throw new RuntimeException("Exception occured while looking up groups for user", ex);
} }
finally { finally {
try { closeContext(context);
if (ic != null) { }
ic.close(); }
}
} private static void closeContext(Context context) {
catch (NamingException ex) { try {
logger.debug("Exception occured while closing context", ex); if (context != null) {
context.close();
} }
} }
catch (NamingException ex) {
logger.debug("Exception occured while closing context", ex);
}
} }
private static Object invokeMethod(Method method, Object instance, Object... args) { private static Object invokeMethod(Method method, Object instance, Object... args) {
try { try {
return method.invoke(instance, args); return method.invoke(instance, args);
} }
catch (IllegalArgumentException ex) { catch (IllegalArgumentException | IllegalAccessException | InvocationTargetException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "(" String message = "Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex); + Arrays.asList(args) + ")";
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "." logger.error(message, ex);
+ method.getName() + "(" + Arrays.asList(args) + ")", ex); throw new RuntimeException(message, ex);
}
catch (IllegalAccessException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex);
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "."
+ method.getName() + "(" + Arrays.asList(args) + ")", ex);
}
catch (InvocationTargetException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex);
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "."
+ method.getName() + "(" + Arrays.asList(args) + ")", ex);
} }
} }
private static Method getMethod(String className, String methodName, String[] parameterTypeNames) { private static Method getMethod(String className, String methodName, String[] parameterTypeNames) {
try { try {
Class<?> c = Class.forName(className); Class<?> c = Class.forName(className);
final int len = parameterTypeNames.length; int len = parameterTypeNames.length;
Class<?>[] parameterTypes = new Class[len]; Class<?>[] parameterTypes = new Class[len];
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
parameterTypes[i] = Class.forName(parameterTypeNames[i]); parameterTypes[i] = Class.forName(parameterTypeNames[i]);

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.websphere;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/** /**
@ -51,9 +52,7 @@ public class WebSpherePreAuthenticatedProcessingFilter extends AbstractPreAuthen
@Override @Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) { protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) {
Object principal = this.wasHelper.getCurrentUserName(); Object principal = this.wasHelper.getCurrentUserName();
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("PreAuthenticated WebSphere principal: %s", principal));
this.logger.debug("PreAuthenticated WebSphere principal: " + principal);
}
return principal; return principal;
} }

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper;
@ -68,9 +69,8 @@ public class WebSpherePreAuthenticatedWebAuthenticationDetailsSource implements
List<String> webSphereGroups = this.wasHelper.getGroupsForCurrentUser(); List<String> webSphereGroups = this.wasHelper.getGroupsForCurrentUser();
Collection<? extends GrantedAuthority> userGas = this.webSphereGroups2GrantedAuthoritiesMapper Collection<? extends GrantedAuthority> userGas = this.webSphereGroups2GrantedAuthoritiesMapper
.getGrantedAuthorities(webSphereGroups); .getGrantedAuthorities(webSphereGroups);
if (this.logger.isDebugEnabled()) { this.logger.debug(
this.logger.debug("WebSphere groups: " + webSphereGroups + " mapped to Granted Authorities: " + userGas); LogMessage.format("WebSphere groups: %s mapped to Granted Authorities: %s", webSphereGroups, userGas));
}
return userGas; return userGas;
} }

View File

@ -25,6 +25,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.context.MessageSource; import org.springframework.context.MessageSource;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -58,24 +59,15 @@ public class SubjectDnX509PrincipalExtractor implements X509PrincipalExtractor {
public Object extractPrincipal(X509Certificate clientCert) { public Object extractPrincipal(X509Certificate clientCert) {
// String subjectDN = clientCert.getSubjectX500Principal().getName(); // String subjectDN = clientCert.getSubjectX500Principal().getName();
String subjectDN = clientCert.getSubjectDN().getName(); String subjectDN = clientCert.getSubjectDN().getName();
this.logger.debug(LogMessage.format("Subject DN is '%s'", subjectDN));
this.logger.debug("Subject DN is '" + subjectDN + "'");
Matcher matcher = this.subjectDnPattern.matcher(subjectDN); Matcher matcher = this.subjectDnPattern.matcher(subjectDN);
if (!matcher.find()) { if (!matcher.find()) {
throw new BadCredentialsException(this.messages.getMessage("SubjectDnX509PrincipalExtractor.noMatching", throw new BadCredentialsException(this.messages.getMessage("SubjectDnX509PrincipalExtractor.noMatching",
new Object[] { subjectDN }, "No matching pattern was found in subject DN: {0}")); new Object[] { subjectDN }, "No matching pattern was found in subject DN: {0}"));
} }
Assert.isTrue(matcher.groupCount() == 1, "Regular expression must contain a single group ");
if (matcher.groupCount() != 1) {
throw new IllegalArgumentException("Regular expression must contain a single group ");
}
String username = matcher.group(1); String username = matcher.group(1);
this.logger.debug(LogMessage.format("Extracted Principal name is '%s'", username));
this.logger.debug("Extracted Principal name is '" + username + "'");
return username; return username;
} }

View File

@ -20,6 +20,7 @@ import java.security.cert.X509Certificate;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/** /**
@ -32,12 +33,7 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing
@Override @Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
X509Certificate cert = extractClientCertificate(request); X509Certificate cert = extractClientCertificate(request);
return (cert != null) ? this.principalExtractor.extractPrincipal(cert) : null;
if (cert == null) {
return null;
}
return this.principalExtractor.extractPrincipal(cert);
} }
@Override @Override
@ -47,19 +43,11 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing
private X509Certificate extractClientCertificate(HttpServletRequest request) { private X509Certificate extractClientCertificate(HttpServletRequest request) {
X509Certificate[] certs = (X509Certificate[]) request.getAttribute("javax.servlet.request.X509Certificate"); X509Certificate[] certs = (X509Certificate[]) request.getAttribute("javax.servlet.request.X509Certificate");
if (certs != null && certs.length > 0) { if (certs != null && certs.length > 0) {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("X.509 client authentication certificate:%s", certs[0]));
this.logger.debug("X.509 client authentication certificate:" + certs[0]);
}
return certs[0]; return certs[0];
} }
this.logger.debug("No client certificate found in request.");
if (this.logger.isDebugEnabled()) {
this.logger.debug("No client certificate found in request.");
}
return null; return null;
} }

View File

@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountStatusException; import org.springframework.security.authentication.AccountStatusException;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
@ -118,47 +119,38 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
@Override @Override
public final Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) { public final Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) {
String rememberMeCookie = extractRememberMeCookie(request); String rememberMeCookie = extractRememberMeCookie(request);
if (rememberMeCookie == null) { if (rememberMeCookie == null) {
return null; return null;
} }
this.logger.debug("Remember-me cookie detected"); this.logger.debug("Remember-me cookie detected");
if (rememberMeCookie.length() == 0) { if (rememberMeCookie.length() == 0) {
this.logger.debug("Cookie was empty"); this.logger.debug("Cookie was empty");
cancelCookie(request, response); cancelCookie(request, response);
return null; return null;
} }
UserDetails user = null;
try { try {
String[] cookieTokens = decodeCookie(rememberMeCookie); String[] cookieTokens = decodeCookie(rememberMeCookie);
user = processAutoLoginCookie(cookieTokens, request, response); UserDetails user = processAutoLoginCookie(cookieTokens, request, response);
this.userDetailsChecker.check(user); this.userDetailsChecker.check(user);
this.logger.debug("Remember-me cookie accepted"); this.logger.debug("Remember-me cookie accepted");
return createSuccessfulAuthentication(request, user); return createSuccessfulAuthentication(request, user);
} }
catch (CookieTheftException cte) { catch (CookieTheftException ex) {
cancelCookie(request, response); cancelCookie(request, response);
throw cte; throw ex;
} }
catch (UsernameNotFoundException noUser) { catch (UsernameNotFoundException ex) {
this.logger.debug("Remember-me login was valid but corresponding user not found.", noUser); this.logger.debug("Remember-me login was valid but corresponding user not found.", ex);
} }
catch (InvalidCookieException invalidCookie) { catch (InvalidCookieException ex) {
this.logger.debug("Invalid remember-me cookie: " + invalidCookie.getMessage()); this.logger.debug("Invalid remember-me cookie: " + ex.getMessage());
} }
catch (AccountStatusException statusInvalid) { catch (AccountStatusException ex) {
this.logger.debug("Invalid UserDetails: " + statusInvalid.getMessage()); this.logger.debug("Invalid UserDetails: " + ex.getMessage());
} }
catch (RememberMeAuthenticationException ex) { catch (RememberMeAuthenticationException ex) {
this.logger.debug(ex.getMessage()); this.logger.debug(ex.getMessage());
} }
cancelCookie(request, response); cancelCookie(request, response);
return null; return null;
} }
@ -172,17 +164,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
*/ */
protected String extractRememberMeCookie(HttpServletRequest request) { protected String extractRememberMeCookie(HttpServletRequest request) {
Cookie[] cookies = request.getCookies(); Cookie[] cookies = request.getCookies();
if ((cookies == null) || (cookies.length == 0)) { if ((cookies == null) || (cookies.length == 0)) {
return null; return null;
} }
for (Cookie cookie : cookies) { for (Cookie cookie : cookies) {
if (this.cookieName.equals(cookie.getName())) { if (this.cookieName.equals(cookie.getName())) {
return cookie.getValue(); return cookie.getValue();
} }
} }
return null; return null;
} }
@ -216,18 +205,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
for (int j = 0; j < cookieValue.length() % 4; j++) { for (int j = 0; j < cookieValue.length() % 4; j++) {
cookieValue = cookieValue + "="; cookieValue = cookieValue + "=";
} }
try { try {
Base64.getDecoder().decode(cookieValue.getBytes()); Base64.getDecoder().decode(cookieValue.getBytes());
} }
catch (IllegalArgumentException ex) { catch (IllegalArgumentException ex) {
throw new InvalidCookieException("Cookie token was not Base64 encoded; value was '" + cookieValue + "'"); throw new InvalidCookieException("Cookie token was not Base64 encoded; value was '" + cookieValue + "'");
} }
String cookieAsPlainText = new String(Base64.getDecoder().decode(cookieValue.getBytes())); String cookieAsPlainText = new String(Base64.getDecoder().decode(cookieValue.getBytes()));
String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER); String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER);
for (int i = 0; i < tokens.length; i++) { for (int i = 0; i < tokens.length; i++) {
try { try {
tokens[i] = URLDecoder.decode(tokens[i], StandardCharsets.UTF_8.toString()); tokens[i] = URLDecoder.decode(tokens[i], StandardCharsets.UTF_8.toString());
@ -236,7 +221,6 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
this.logger.error(ex.getMessage(), ex); this.logger.error(ex.getMessage(), ex);
} }
} }
return tokens; return tokens;
} }
@ -254,20 +238,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
catch (UnsupportedEncodingException ex) { catch (UnsupportedEncodingException ex) {
this.logger.error(ex.getMessage(), ex); this.logger.error(ex.getMessage(), ex);
} }
if (i < cookieTokens.length - 1) { if (i < cookieTokens.length - 1) {
sb.append(DELIMITER); sb.append(DELIMITER);
} }
} }
String value = sb.toString(); String value = sb.toString();
sb = new StringBuilder(new String(Base64.getEncoder().encode(value.getBytes()))); sb = new StringBuilder(new String(Base64.getEncoder().encode(value.getBytes())));
while (sb.charAt(sb.length() - 1) == '=') { while (sb.charAt(sb.length() - 1) == '=') {
sb.deleteCharAt(sb.length() - 1); sb.deleteCharAt(sb.length() - 1);
} }
return sb.toString(); return sb.toString();
} }
@ -293,12 +272,10 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
@Override @Override
public final void loginSuccess(HttpServletRequest request, HttpServletResponse response, public final void loginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) { Authentication successfulAuthentication) {
if (!rememberMeRequested(request, this.parameter)) { if (!rememberMeRequested(request, this.parameter)) {
this.logger.debug("Remember-me login not requested."); this.logger.debug("Remember-me login not requested.");
return; return;
} }
onLoginSuccess(request, response, successfulAuthentication); onLoginSuccess(request, response, successfulAuthentication);
} }
@ -324,20 +301,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (this.alwaysRemember) { if (this.alwaysRemember) {
return true; return true;
} }
String paramValue = request.getParameter(parameter); String paramValue = request.getParameter(parameter);
if (paramValue != null) { if (paramValue != null) {
if (paramValue.equalsIgnoreCase("true") || paramValue.equalsIgnoreCase("on") if (paramValue.equalsIgnoreCase("true") || paramValue.equalsIgnoreCase("on")
|| paramValue.equalsIgnoreCase("yes") || paramValue.equals("1")) { || paramValue.equalsIgnoreCase("yes") || paramValue.equals("1")) {
return true; return true;
} }
} }
this.logger.debug(
if (this.logger.isDebugEnabled()) { LogMessage.format("Did not send remember-me cookie (principal did not set parameter '%s')", parameter));
this.logger.debug("Did not send remember-me cookie (principal did not set parameter '" + parameter + "')");
}
return false; return false;
} }
@ -370,12 +342,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (this.cookieDomain != null) { if (this.cookieDomain != null) {
cookie.setDomain(this.cookieDomain); cookie.setDomain(this.cookieDomain);
} }
if (this.useSecureCookie == null) { cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure());
cookie.setSecure(request.isSecure());
}
else {
cookie.setSecure(this.useSecureCookie);
}
response.addCookie(cookie); response.addCookie(cookie);
} }
@ -402,16 +369,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (maxAge < 1) { if (maxAge < 1) {
cookie.setVersion(1); cookie.setVersion(1);
} }
cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure());
if (this.useSecureCookie == null) {
cookie.setSecure(request.isSecure());
}
else {
cookie.setSecure(this.useSecureCookie);
}
cookie.setHttpOnly(true); cookie.setHttpOnly(true);
response.addCookie(cookie); response.addCookie(cookie);
} }
@ -426,9 +385,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
*/ */
@Override @Override
public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage
this.logger.debug("Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown")); .of(() -> "Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown")));
}
cancelCookie(request, response); cancelCookie(request, response);
} }

View File

@ -36,21 +36,17 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository {
@Override @Override
public synchronized void createNewToken(PersistentRememberMeToken token) { public synchronized void createNewToken(PersistentRememberMeToken token) {
PersistentRememberMeToken current = this.seriesTokens.get(token.getSeries()); PersistentRememberMeToken current = this.seriesTokens.get(token.getSeries());
if (current != null) { if (current != null) {
throw new DataIntegrityViolationException("Series Id '" + token.getSeries() + "' already exists!"); throw new DataIntegrityViolationException("Series Id '" + token.getSeries() + "' already exists!");
} }
this.seriesTokens.put(token.getSeries(), token); this.seriesTokens.put(token.getSeries(), token);
} }
@Override @Override
public synchronized void updateToken(String series, String tokenValue, Date lastUsed) { public synchronized void updateToken(String series, String tokenValue, Date lastUsed) {
PersistentRememberMeToken token = getTokenForSeries(series); PersistentRememberMeToken token = getTokenForSeries(series);
PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), series, tokenValue, PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), series, tokenValue,
new Date()); new Date());
// Store it, overwriting the existing one. // Store it, overwriting the existing one.
this.seriesTokens.put(series, newToken); this.seriesTokens.put(series, newToken);
} }
@ -63,12 +59,9 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository {
@Override @Override
public synchronized void removeUserTokens(String username) { public synchronized void removeUserTokens(String username) {
Iterator<String> series = this.seriesTokens.keySet().iterator(); Iterator<String> series = this.seriesTokens.keySet().iterator();
while (series.hasNext()) { while (series.hasNext()) {
String seriesId = series.next(); String seriesId = series.next();
PersistentRememberMeToken token = this.seriesTokens.get(seriesId); PersistentRememberMeToken token = this.seriesTokens.get(seriesId);
if (username.equals(token.getUsername())) { if (username.equals(token.getUsername())) {
series.remove(); series.remove();
} }

View File

@ -16,8 +16,11 @@
package org.springframework.security.web.authentication.rememberme; package org.springframework.security.web.authentication.rememberme;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Date; import java.util.Date;
import org.springframework.core.log.LogMessage;
import org.springframework.dao.DataAccessException; import org.springframework.dao.DataAccessException;
import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException;
@ -87,27 +90,26 @@ public class JdbcTokenRepositoryImpl extends JdbcDaoSupport implements Persisten
@Override @Override
public PersistentRememberMeToken getTokenForSeries(String seriesId) { public PersistentRememberMeToken getTokenForSeries(String seriesId) {
try { try {
return getJdbcTemplate().queryForObject(this.tokensBySeriesSql, return getJdbcTemplate().queryForObject(this.tokensBySeriesSql, this::createRememberMeToken, seriesId);
(rs, rowNum) -> new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3),
rs.getTimestamp(4)),
seriesId);
} }
catch (EmptyResultDataAccessException zeroResults) { catch (EmptyResultDataAccessException ex) {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Querying token for series '%s' returned no results.", seriesId), ex);
this.logger.debug("Querying token for series '" + seriesId + "' returned no results.", zeroResults);
}
} }
catch (IncorrectResultSizeDataAccessException moreThanOne) { catch (IncorrectResultSizeDataAccessException ex) {
this.logger.error("Querying token for series '" + seriesId + "' returned more than one value. Series" this.logger.error(LogMessage.format(
+ " should be unique"); "Querying token for series '%s' returned more than one value. Series" + " should be unique",
seriesId));
} }
catch (DataAccessException ex) { catch (DataAccessException ex) {
this.logger.error("Failed to load token for series " + seriesId, ex); this.logger.error("Failed to load token for series " + seriesId, ex);
} }
return null; return null;
} }
private PersistentRememberMeToken createRememberMeToken(ResultSet rs, int rowNum) throws SQLException {
return new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3), rs.getTimestamp(4));
}
@Override @Override
public void removeUserTokens(String username) { public void removeUserTokens(String username) {
getJdbcTemplate().update(this.removeUserTokensSql, username); getJdbcTemplate().update(this.removeUserTokensSql, username);

View File

@ -24,6 +24,7 @@ import java.util.Date;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
@ -93,47 +94,35 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
@Override @Override
protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request, protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
if (cookieTokens.length != 2) { if (cookieTokens.length != 2) {
throw new InvalidCookieException("Cookie token did not contain " + 2 + " tokens, but contained '" throw new InvalidCookieException("Cookie token did not contain " + 2 + " tokens, but contained '"
+ Arrays.asList(cookieTokens) + "'"); + Arrays.asList(cookieTokens) + "'");
} }
String presentedSeries = cookieTokens[0];
final String presentedSeries = cookieTokens[0]; String presentedToken = cookieTokens[1];
final String presentedToken = cookieTokens[1];
PersistentRememberMeToken token = this.tokenRepository.getTokenForSeries(presentedSeries); PersistentRememberMeToken token = this.tokenRepository.getTokenForSeries(presentedSeries);
if (token == null) { if (token == null) {
// No series match, so we can't authenticate using this cookie // No series match, so we can't authenticate using this cookie
throw new RememberMeAuthenticationException("No persistent token found for series id: " + presentedSeries); throw new RememberMeAuthenticationException("No persistent token found for series id: " + presentedSeries);
} }
// We have a match for this user/series combination // We have a match for this user/series combination
if (!presentedToken.equals(token.getTokenValue())) { if (!presentedToken.equals(token.getTokenValue())) {
// Token doesn't match series value. Delete all logins for this user and throw // Token doesn't match series value. Delete all logins for this user and throw
// an exception to warn them. // an exception to warn them.
this.tokenRepository.removeUserTokens(token.getUsername()); this.tokenRepository.removeUserTokens(token.getUsername());
throw new CookieTheftException(this.messages.getMessage( throw new CookieTheftException(this.messages.getMessage(
"PersistentTokenBasedRememberMeServices.cookieStolen", "PersistentTokenBasedRememberMeServices.cookieStolen",
"Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack.")); "Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack."));
} }
if (token.getDate().getTime() + getTokenValiditySeconds() * 1000L < System.currentTimeMillis()) { if (token.getDate().getTime() + getTokenValiditySeconds() * 1000L < System.currentTimeMillis()) {
throw new RememberMeAuthenticationException("Remember-me login has expired"); throw new RememberMeAuthenticationException("Remember-me login has expired");
} }
// Token also matches, so login is valid. Update the token value, keeping the // Token also matches, so login is valid. Update the token value, keeping the
// *same* series number. // *same* series number.
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Refreshing persistent login token for user '%s', series '%s'",
this.logger.debug("Refreshing persistent login token for user '" + token.getUsername() + "', series '" token.getUsername(), token.getSeries()));
+ token.getSeries() + "'");
}
PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), token.getSeries(), PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), token.getSeries(),
generateTokenData(), new Date()); generateTokenData(), new Date());
try { try {
this.tokenRepository.updateToken(newToken.getSeries(), newToken.getTokenValue(), newToken.getDate()); this.tokenRepository.updateToken(newToken.getSeries(), newToken.getTokenValue(), newToken.getDate());
addCookie(newToken, request, response); addCookie(newToken, request, response);
@ -142,7 +131,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
this.logger.error("Failed to update token: ", ex); this.logger.error("Failed to update token: ", ex);
throw new RememberMeAuthenticationException("Autologin failed due to data access problem"); throw new RememberMeAuthenticationException("Autologin failed due to data access problem");
} }
return getUserDetailsService().loadUserByUsername(token.getUsername()); return getUserDetailsService().loadUserByUsername(token.getUsername());
} }
@ -155,9 +143,7 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
protected void onLoginSuccess(HttpServletRequest request, HttpServletResponse response, protected void onLoginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) { Authentication successfulAuthentication) {
String username = successfulAuthentication.getName(); String username = successfulAuthentication.getName();
this.logger.debug(LogMessage.format("Creating new persistent login for user %s", username));
this.logger.debug("Creating new persistent login for user " + username);
PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, generateSeriesData(), PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, generateSeriesData(),
generateTokenData(), new Date()); generateTokenData(), new Date());
try { try {
@ -172,7 +158,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
@Override @Override
public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
super.logout(request, response, authentication); super.logout(request, response, authentication);
if (authentication != null) { if (authentication != null) {
this.tokenRepository.removeUserTokens(authentication.getName()); this.tokenRepository.removeUserTokens(authentication.getName());
} }

View File

@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent; import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -86,66 +87,50 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
if (SecurityContextHolder.getContext().getAuthentication() == null) {
Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response);
if (rememberMeAuth != null) {
// Attempt authenticaton via AuthenticationManager
try {
rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth);
// Store to SecurityContextHolder
SecurityContextHolder.getContext().setAuthentication(rememberMeAuth);
onSuccessfulAuthentication(request, response, rememberMeAuth);
if (this.logger.isDebugEnabled()) {
this.logger.debug("SecurityContextHolder populated with remember-me token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
}
// Fire event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
}
if (this.successHandler != null) {
this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth);
return;
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (SecurityContextHolder.getContext().getAuthentication() != null) {
this.logger.debug(LogMessage
.of(() -> "SecurityContextHolder not populated with remember-me token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
chain.doFilter(request, response);
return;
}
Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response);
if (rememberMeAuth != null) {
// Attempt authenticaton via AuthenticationManager
try {
rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth);
// Store to SecurityContextHolder
SecurityContextHolder.getContext().setAuthentication(rememberMeAuth);
onSuccessfulAuthentication(request, response, rememberMeAuth);
this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
} }
catch (AuthenticationException authenticationException) { if (this.successHandler != null) {
if (this.logger.isDebugEnabled()) { this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth);
this.logger.debug("SecurityContextHolder not populated with remember-me token, as " return;
+ "AuthenticationManager rejected Authentication returned by RememberMeServices: '"
+ rememberMeAuth + "'; invalidating remember-me token", authenticationException);
}
this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, authenticationException);
} }
} }
catch (AuthenticationException ex) {
chain.doFilter(request, response); this.logger.debug(LogMessage
} .format("SecurityContextHolder not populated with remember-me token, as AuthenticationManager "
else { + "rejected Authentication returned by RememberMeServices: '%s'; "
if (this.logger.isDebugEnabled()) { + "invalidating remember-me token", rememberMeAuth),
this.logger ex);
.debug("SecurityContextHolder not populated with remember-me token, as it already contained: '" this.rememberMeServices.loginFail(request, response);
+ SecurityContextHolder.getContext().getAuthentication() + "'"); onUnsuccessfulAuthentication(request, response, ex);
} }
chain.doFilter(request, response);
} }
chain.doFilter(request, response);
} }
/** /**

View File

@ -90,52 +90,43 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
@Override @Override
protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request, protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
if (cookieTokens.length != 3) { if (cookieTokens.length != 3) {
throw new InvalidCookieException( throw new InvalidCookieException(
"Cookie token did not contain 3" + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'"); "Cookie token did not contain 3" + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'");
} }
long tokenExpiryTime = getTokenExpiryTime(cookieTokens);
if (isTokenExpired(tokenExpiryTime)) {
throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime)
+ "'; current time is '" + new Date() + "')");
}
// Check the user exists. Defer lookup until after expiry time checked, to
// possibly avoid expensive database call.
UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]);
Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService()
+ " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation");
// Check signature of token matches remaining details. Must do this after user
// lookup, as we need the DAO-derived password. If efficiency was a major issue,
// just add in a UserCache implementation, but recall that this method is usually
// only called once per HttpSession - if the token is valid, it will cause
// SecurityContextHolder population, whilst if invalid, will cause the cookie to
// be cancelled.
String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(),
userDetails.getPassword());
if (!equals(expectedTokenSignature, cookieTokens[2])) {
throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2]
+ "' but expected '" + expectedTokenSignature + "'");
}
return userDetails;
}
long tokenExpiryTime; private long getTokenExpiryTime(String[] cookieTokens) {
try { try {
tokenExpiryTime = new Long(cookieTokens[1]); return new Long(cookieTokens[1]);
} }
catch (NumberFormatException nfe) { catch (NumberFormatException nfe) {
throw new InvalidCookieException( throw new InvalidCookieException(
"Cookie token[1] did not contain a valid number (contained '" + cookieTokens[1] + "')"); "Cookie token[1] did not contain a valid number (contained '" + cookieTokens[1] + "')");
} }
if (isTokenExpired(tokenExpiryTime)) {
throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime)
+ "'; current time is '" + new Date() + "')");
}
// Check the user exists.
// Defer lookup until after expiry time checked, to possibly avoid expensive
// database call.
UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]);
Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService()
+ " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation");
// Check signature of token matches remaining details.
// Must do this after user lookup, as we need the DAO-derived password.
// If efficiency was a major issue, just add in a UserCache implementation,
// but recall that this method is usually only called once per HttpSession - if
// the token is valid,
// it will cause SecurityContextHolder population, whilst if invalid, will cause
// the cookie to be cancelled.
String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(),
userDetails.getPassword());
if (!equals(expectedTokenSignature, cookieTokens[2])) {
throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2]
+ "' but expected '" + expectedTokenSignature + "'");
}
return userDetails;
} }
/** /**
@ -144,15 +135,13 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
*/ */
protected String makeTokenSignature(long tokenExpiryTime, String username, String password) { protected String makeTokenSignature(long tokenExpiryTime, String username, String password) {
String data = username + ":" + tokenExpiryTime + ":" + password + ":" + getKey(); String data = username + ":" + tokenExpiryTime + ":" + password + ":" + getKey();
MessageDigest digest;
try { try {
digest = MessageDigest.getInstance("MD5"); MessageDigest digest = MessageDigest.getInstance("MD5");
return new String(Hex.encode(digest.digest(data.getBytes())));
} }
catch (NoSuchAlgorithmException ex) { catch (NoSuchAlgorithmException ex) {
throw new IllegalStateException("No MD5 algorithm available!"); throw new IllegalStateException("No MD5 algorithm available!");
} }
return new String(Hex.encode(digest.digest(data.getBytes())));
} }
protected boolean isTokenExpired(long tokenExpiryTime) { protected boolean isTokenExpired(long tokenExpiryTime) {
@ -162,10 +151,8 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
@Override @Override
public void onLoginSuccess(HttpServletRequest request, HttpServletResponse response, public void onLoginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) { Authentication successfulAuthentication) {
String username = retrieveUserName(successfulAuthentication); String username = retrieveUserName(successfulAuthentication);
String password = retrievePassword(successfulAuthentication); String password = retrievePassword(successfulAuthentication);
// If unable to find a username and password, just abort as // If unable to find a username and password, just abort as
// TokenBasedRememberMeServices is // TokenBasedRememberMeServices is
// unable to construct a valid token in this case. // unable to construct a valid token in this case.
@ -173,27 +160,21 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
this.logger.debug("Unable to retrieve username"); this.logger.debug("Unable to retrieve username");
return; return;
} }
if (!StringUtils.hasLength(password)) { if (!StringUtils.hasLength(password)) {
UserDetails user = getUserDetailsService().loadUserByUsername(username); UserDetails user = getUserDetailsService().loadUserByUsername(username);
password = user.getPassword(); password = user.getPassword();
if (!StringUtils.hasLength(password)) { if (!StringUtils.hasLength(password)) {
this.logger.debug("Unable to obtain password for user: " + username); this.logger.debug("Unable to obtain password for user: " + username);
return; return;
} }
} }
int tokenLifetime = calculateLoginLifetime(request, successfulAuthentication); int tokenLifetime = calculateLoginLifetime(request, successfulAuthentication);
long expiryTime = System.currentTimeMillis(); long expiryTime = System.currentTimeMillis();
// SEC-949 // SEC-949
expiryTime += 1000L * ((tokenLifetime < 0) ? TWO_WEEKS_S : tokenLifetime); expiryTime += 1000L * ((tokenLifetime < 0) ? TWO_WEEKS_S : tokenLifetime);
String signatureValue = makeTokenSignature(expiryTime, username, password); String signatureValue = makeTokenSignature(expiryTime, username, password);
setCookie(new String[] { username, Long.toString(expiryTime), signatureValue }, tokenLifetime, request, setCookie(new String[] { username, Long.toString(expiryTime), signatureValue }, tokenLifetime, request,
response); response);
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug( this.logger.debug(
"Added remember-me cookie for user '" + username + "', expiry: '" + new Date(expiryTime) + "'"); "Added remember-me cookie for user '" + username + "', expiry: '" + new Date(expiryTime) + "'");
@ -223,21 +204,17 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
if (isInstanceOfUserDetails(authentication)) { if (isInstanceOfUserDetails(authentication)) {
return ((UserDetails) authentication.getPrincipal()).getUsername(); return ((UserDetails) authentication.getPrincipal()).getUsername();
} }
else { return authentication.getPrincipal().toString();
return authentication.getPrincipal().toString();
}
} }
protected String retrievePassword(Authentication authentication) { protected String retrievePassword(Authentication authentication) {
if (isInstanceOfUserDetails(authentication)) { if (isInstanceOfUserDetails(authentication)) {
return ((UserDetails) authentication.getPrincipal()).getPassword(); return ((UserDetails) authentication.getPrincipal()).getPassword();
} }
else { if (authentication.getCredentials() != null) {
if (authentication.getCredentials() == null) {
return null;
}
return authentication.getCredentials().toString(); return authentication.getCredentials().toString();
} }
return null;
} }
private boolean isInstanceOfUserDetails(Authentication authentication) { private boolean isInstanceOfUserDetails(Authentication authentication) {
@ -250,15 +227,11 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
private static boolean equals(String expected, String actual) { private static boolean equals(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected); byte[] expectedBytes = bytesUtf8(expected);
byte[] actualBytes = bytesUtf8(actual); byte[] actualBytes = bytesUtf8(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes); return MessageDigest.isEqual(expectedBytes, actualBytes);
} }
private static byte[] bytesUtf8(String s) { private static byte[] bytesUtf8(String s) {
if (s == null) { return (s != null) ? Utf8.encode(s) : null;
return null;
}
return Utf8.encode(s);
} }
} }

View File

@ -73,35 +73,26 @@ public abstract class AbstractSessionFixationProtectionStrategy
public void onAuthentication(Authentication authentication, HttpServletRequest request, public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
boolean hadSessionAlready = request.getSession(false) != null; boolean hadSessionAlready = request.getSession(false) != null;
if (!hadSessionAlready && !this.alwaysCreateSession) { if (!hadSessionAlready && !this.alwaysCreateSession) {
// Session fixation isn't a problem if there's no session // Session fixation isn't a problem if there's no session
return; return;
} }
// Create new session if necessary // Create new session if necessary
HttpSession session = request.getSession(); HttpSession session = request.getSession();
if (hadSessionAlready && request.isRequestedSessionIdValid()) { if (hadSessionAlready && request.isRequestedSessionIdValid()) {
String originalSessionId; String originalSessionId;
String newSessionId; String newSessionId;
Object mutex = WebUtils.getSessionMutex(session); Object mutex = WebUtils.getSessionMutex(session);
synchronized (mutex) { synchronized (mutex) {
// We need to migrate to a new session // We need to migrate to a new session
originalSessionId = session.getId(); originalSessionId = session.getId();
session = applySessionFixation(request); session = applySessionFixation(request);
newSessionId = session.getId(); newSessionId = session.getId();
} }
if (originalSessionId.equals(newSessionId)) { if (originalSessionId.equals(newSessionId)) {
this.logger.warn( this.logger.warn("Your servlet container did not change the session ID when a new session "
"Your servlet container did not change the session ID when a new session was created. You will" + "was created. You will not be adequately protected against session-fixation attacks");
+ " not be adequately protected against session-fixation attacks");
} }
onSessionChange(originalSessionId, session, authentication); onSessionChange(originalSessionId, session, authentication);
} }
} }

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -63,10 +64,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica
public CompositeSessionAuthenticationStrategy(List<SessionAuthenticationStrategy> delegateStrategies) { public CompositeSessionAuthenticationStrategy(List<SessionAuthenticationStrategy> delegateStrategies) {
Assert.notEmpty(delegateStrategies, "delegateStrategies cannot be null or empty"); Assert.notEmpty(delegateStrategies, "delegateStrategies cannot be null or empty");
for (SessionAuthenticationStrategy strategy : delegateStrategies) { for (SessionAuthenticationStrategy strategy : delegateStrategies) {
if (strategy == null) { Assert.notNull(strategy, () -> "delegateStrategies cannot contain null entires. Got " + delegateStrategies);
throw new IllegalArgumentException(
"delegateStrategies cannot contain null entires. Got " + delegateStrategies);
}
} }
this.delegateStrategies = delegateStrategies; this.delegateStrategies = delegateStrategies;
} }
@ -75,9 +73,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica
public void onAuthentication(Authentication authentication, HttpServletRequest request, public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) throws SessionAuthenticationException { HttpServletResponse response) throws SessionAuthenticationException {
for (SessionAuthenticationStrategy delegate : this.delegateStrategies) { for (SessionAuthenticationStrategy delegate : this.delegateStrategies) {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Delegating to %s", delegate));
this.logger.debug("Delegating to " + delegate);
}
delegate.onAuthentication(authentication, request, response); delegate.onAuthentication(authentication, request, response);
} }
} }

View File

@ -94,26 +94,19 @@ public class ConcurrentSessionControlAuthenticationStrategy
@Override @Override
public void onAuthentication(Authentication authentication, HttpServletRequest request, public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(), false);
final List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(),
false);
int sessionCount = sessions.size(); int sessionCount = sessions.size();
int allowedSessions = getMaximumSessionsForThisUser(authentication); int allowedSessions = getMaximumSessionsForThisUser(authentication);
if (sessionCount < allowedSessions) { if (sessionCount < allowedSessions) {
// They haven't got too many login sessions running at present // They haven't got too many login sessions running at present
return; return;
} }
if (allowedSessions == -1) { if (allowedSessions == -1) {
// We permit unlimited logins // We permit unlimited logins
return; return;
} }
if (sessionCount == allowedSessions) { if (sessionCount == allowedSessions) {
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) { if (session != null) {
// Only permit it though if this request is associated with one of the // Only permit it though if this request is associated with one of the
// already registered sessions // already registered sessions
@ -126,7 +119,6 @@ public class ConcurrentSessionControlAuthenticationStrategy
// If the session is null, a new one will be created by the parent class, // If the session is null, a new one will be created by the parent class,
// exceeding the allowed number // exceeding the allowed number
} }
allowableSessionsExceeded(sessions, allowedSessions, this.sessionRegistry); allowableSessionsExceeded(sessions, allowedSessions, this.sessionRegistry);
} }
@ -157,7 +149,6 @@ public class ConcurrentSessionControlAuthenticationStrategy
this.messages.getMessage("ConcurrentSessionControlAuthenticationStrategy.exceededAllowed", this.messages.getMessage("ConcurrentSessionControlAuthenticationStrategy.exceededAllowed",
new Object[] { allowableSessions }, "Maximum sessions of {0} for this principal exceeded")); new Object[] { allowableSessions }, "Maximum sessions of {0} for this principal exceeded"));
} }
// Determine least recently used sessions, and mark them for invalidation // Determine least recently used sessions, and mark them for invalidation
sessions.sort(Comparator.comparing(SessionInformation::getLastRequest)); sessions.sort(Comparator.comparing(SessionInformation::getLastRequest));
int maximumSessionsExceededBy = sessions.size() - allowableSessions + 1; int maximumSessionsExceededBy = sessions.size() - allowableSessions + 1;

View File

@ -23,6 +23,8 @@ import java.util.Map;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.springframework.core.log.LogMessage;
/** /**
* Uses {@code HttpServletRequest.invalidate()} to protect against session fixation * Uses {@code HttpServletRequest.invalidate()} to protect against session fixation
* attacks. * attacks.
@ -82,21 +84,13 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr
final HttpSession applySessionFixation(HttpServletRequest request) { final HttpSession applySessionFixation(HttpServletRequest request) {
HttpSession session = request.getSession(); HttpSession session = request.getSession();
String originalSessionId = session.getId(); String originalSessionId = session.getId();
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.of(() -> "Invalidating session with Id '" + originalSessionId + "' "
this.logger.debug("Invalidating session with Id '" + originalSessionId + "' " + (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes."));
+ (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes.");
}
Map<String, Object> attributesToMigrate = extractAttributes(session); Map<String, Object> attributesToMigrate = extractAttributes(session);
int maxInactiveIntervalToMigrate = session.getMaxInactiveInterval(); int maxInactiveIntervalToMigrate = session.getMaxInactiveInterval();
session.invalidate(); session.invalidate();
session = request.getSession(true); // we now have a new session session = request.getSession(true); // we now have a new session
this.logger.debug(LogMessage.format("Started new session: %s", session.getId()));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Started new session: " + session.getId());
}
transferAttributes(attributesToMigrate, session); transferAttributes(attributesToMigrate, session);
if (this.migrateSessionAttributes) { if (this.migrateSessionAttributes) {
session.setMaxInactiveInterval(maxInactiveIntervalToMigrate); session.setMaxInactiveInterval(maxInactiveIntervalToMigrate);
@ -111,27 +105,22 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr
*/ */
void transferAttributes(Map<String, Object> attributes, HttpSession newSession) { void transferAttributes(Map<String, Object> attributes, HttpSession newSession) {
if (attributes != null) { if (attributes != null) {
for (Map.Entry<String, Object> entry : attributes.entrySet()) { attributes.forEach(newSession::setAttribute);
newSession.setAttribute(entry.getKey(), entry.getValue());
}
} }
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private HashMap<String, Object> createMigratedAttributeMap(HttpSession session) { private HashMap<String, Object> createMigratedAttributeMap(HttpSession session) {
HashMap<String, Object> attributesToMigrate = new HashMap<>(); HashMap<String, Object> attributesToMigrate = new HashMap<>();
Enumeration<String> enumeration = session.getAttributeNames();
Enumeration enumer = session.getAttributeNames(); while (enumeration.hasMoreElements()) {
String key = enumeration.nextElement();
while (enumer.hasMoreElements()) {
String key = (String) enumer.nextElement();
if (!this.migrateSessionAttributes && !key.startsWith("SPRING_SECURITY_")) { if (!this.migrateSessionAttributes && !key.startsWith("SPRING_SECURITY_")) {
// Only retain Spring Security attributes // Only retain Spring Security attributes
continue; continue;
} }
attributesToMigrate.put(key, session.getAttribute(key)); attributesToMigrate.put(key, session.getAttribute(key));
} }
return attributesToMigrate; return attributesToMigrate;
} }

View File

@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.MessageSource; import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware; import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountExpiredException; import org.springframework.security.authentication.AccountExpiredException;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
@ -149,7 +150,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
Assert.isNull(this.successHandler, "You cannot set both successHandler and targetUrl"); Assert.isNull(this.successHandler, "You cannot set both successHandler and targetUrl");
this.successHandler = new SimpleUrlAuthenticationSuccessHandler(this.targetUrl); this.successHandler = new SimpleUrlAuthenticationSuccessHandler(this.targetUrl);
} }
if (this.failureHandler == null) { if (this.failureHandler == null) {
this.failureHandler = (this.switchFailureUrl != null) this.failureHandler = (this.switchFailureUrl != null)
? new SimpleUrlAuthenticationFailureHandler(this.switchFailureUrl) ? new SimpleUrlAuthenticationFailureHandler(this.switchFailureUrl)
@ -161,20 +161,20 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
// check for switch or exit request // check for switch or exit request
if (requiresSwitchUser(request)) { if (requiresSwitchUser(request)) {
// if set, attempt switch and store original // if set, attempt switch and store original
try { try {
Authentication targetUser = attemptSwitchUser(request); Authentication targetUser = attemptSwitchUser(request);
// update the current context to the new target user // update the current context to the new target user
SecurityContextHolder.getContext().setAuthentication(targetUser); SecurityContextHolder.getContext().setAuthentication(targetUser);
// redirect to target url // redirect to target url
this.successHandler.onAuthenticationSuccess(request, response, targetUser); this.successHandler.onAuthenticationSuccess(request, response, targetUser);
} }
@ -182,22 +182,17 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
this.logger.debug("Switch User failed", ex); this.logger.debug("Switch User failed", ex);
this.failureHandler.onAuthenticationFailure(request, response, ex); this.failureHandler.onAuthenticationFailure(request, response, ex);
} }
return; return;
} }
else if (requiresExitUser(request)) { if (requiresExitUser(request)) {
// get the original authentication object (if exists) // get the original authentication object (if exists)
Authentication originalUser = attemptExitUser(request); Authentication originalUser = attemptExitUser(request);
// update the current context back to the original user // update the current context back to the original user
SecurityContextHolder.getContext().setAuthentication(originalUser); SecurityContextHolder.getContext().setAuthentication(originalUser);
// redirect to target url // redirect to target url
this.successHandler.onAuthenticationSuccess(request, response, originalUser); this.successHandler.onAuthenticationSuccess(request, response, originalUser);
return; return;
} }
chain.doFilter(request, response); chain.doFilter(request, response);
} }
@ -214,33 +209,19 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/ */
protected Authentication attemptSwitchUser(HttpServletRequest request) throws AuthenticationException { protected Authentication attemptSwitchUser(HttpServletRequest request) throws AuthenticationException {
UsernamePasswordAuthenticationToken targetUserRequest; UsernamePasswordAuthenticationToken targetUserRequest;
String username = request.getParameter(this.usernameParameter); String username = request.getParameter(this.usernameParameter);
username = (username != null) ? username : "";
if (username == null) { this.logger.debug(LogMessage.format("Attempt to switch to user [%s]", username));
username = "";
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Attempt to switch to user [" + username + "]");
}
UserDetails targetUser = this.userDetailsService.loadUserByUsername(username); UserDetails targetUser = this.userDetailsService.loadUserByUsername(username);
this.userDetailsChecker.check(targetUser); this.userDetailsChecker.check(targetUser);
// OK, create the switch user token // OK, create the switch user token
targetUserRequest = createSwitchUserToken(request, targetUser); targetUserRequest = createSwitchUserToken(request, targetUser);
this.logger.debug(LogMessage.format("Switch User Token [%s]", targetUserRequest));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Switch User Token [" + targetUserRequest + "]");
}
// publish event // publish event
if (this.eventPublisher != null) { if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent( this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(
SecurityContextHolder.getContext().getAuthentication(), targetUser)); SecurityContextHolder.getContext().getAuthentication(), targetUser));
} }
return targetUserRequest; return targetUserRequest;
} }
@ -256,35 +237,28 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
throws AuthenticationCredentialsNotFoundException { throws AuthenticationCredentialsNotFoundException {
// need to check to see if the current user has a SwitchUserGrantedAuthority // need to check to see if the current user has a SwitchUserGrantedAuthority
Authentication current = SecurityContextHolder.getContext().getAuthentication(); Authentication current = SecurityContextHolder.getContext().getAuthentication();
if (current == null) {
if (null == current) {
throw new AuthenticationCredentialsNotFoundException(this.messages throw new AuthenticationCredentialsNotFoundException(this.messages
.getMessage("SwitchUserFilter.noCurrentUser", "No current user associated with this request")); .getMessage("SwitchUserFilter.noCurrentUser", "No current user associated with this request"));
} }
// check to see if the current user did actual switch to another user // check to see if the current user did actual switch to another user
// if so, get the original source user so we can switch back // if so, get the original source user so we can switch back
Authentication original = getSourceAuthentication(current); Authentication original = getSourceAuthentication(current);
if (original == null) { if (original == null) {
this.logger.debug("Could not find original user Authentication object!"); this.logger.debug("Could not find original user Authentication object!");
throw new AuthenticationCredentialsNotFoundException(this.messages.getMessage( throw new AuthenticationCredentialsNotFoundException(this.messages.getMessage(
"SwitchUserFilter.noOriginalAuthentication", "Could not find original Authentication object")); "SwitchUserFilter.noOriginalAuthentication", "Could not find original Authentication object"));
} }
// get the source user details // get the source user details
UserDetails originalUser = null; UserDetails originalUser = null;
Object obj = original.getPrincipal(); Object obj = original.getPrincipal();
if ((obj != null) && obj instanceof UserDetails) { if ((obj != null) && obj instanceof UserDetails) {
originalUser = (UserDetails) obj; originalUser = (UserDetails) obj;
} }
// publish event // publish event
if (this.eventPublisher != null) { if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(current, originalUser)); this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(current, originalUser));
} }
return original; return original;
} }
@ -299,45 +273,38 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/ */
private UsernamePasswordAuthenticationToken createSwitchUserToken(HttpServletRequest request, private UsernamePasswordAuthenticationToken createSwitchUserToken(HttpServletRequest request,
UserDetails targetUser) { UserDetails targetUser) {
UsernamePasswordAuthenticationToken targetUserRequest; UsernamePasswordAuthenticationToken targetUserRequest;
// grant an additional authority that contains the original Authentication object // grant an additional authority that contains the original Authentication object
// which will be used to 'exit' from the current switched user. // which will be used to 'exit' from the current switched user.
Authentication currentAuthentication = getCurrentAuthentication(request);
Authentication currentAuth; GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole,
currentAuthentication);
try {
// SEC-1763. Check first if we are already switched.
currentAuth = attemptExitUser(request);
}
catch (AuthenticationCredentialsNotFoundException ex) {
currentAuth = SecurityContextHolder.getContext().getAuthentication();
}
GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole, currentAuth);
// get the original authorities // get the original authorities
Collection<? extends GrantedAuthority> orig = targetUser.getAuthorities(); Collection<? extends GrantedAuthority> orig = targetUser.getAuthorities();
// Allow subclasses to change the authorities to be granted // Allow subclasses to change the authorities to be granted
if (this.switchUserAuthorityChanger != null) { if (this.switchUserAuthorityChanger != null) {
orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuth, orig); orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuthentication, orig);
} }
// add the new switch user authority // add the new switch user authority
List<GrantedAuthority> newAuths = new ArrayList<>(orig); List<GrantedAuthority> newAuths = new ArrayList<>(orig);
newAuths.add(switchAuthority); newAuths.add(switchAuthority);
// create the new authentication token // create the new authentication token
targetUserRequest = new UsernamePasswordAuthenticationToken(targetUser, targetUser.getPassword(), newAuths); targetUserRequest = new UsernamePasswordAuthenticationToken(targetUser, targetUser.getPassword(), newAuths);
// set details // set details
targetUserRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); targetUserRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
return targetUserRequest; return targetUserRequest;
} }
private Authentication getCurrentAuthentication(HttpServletRequest request) {
try {
// SEC-1763. Check first if we are already switched.
return attemptExitUser(request);
}
catch (AuthenticationCredentialsNotFoundException ex) {
return SecurityContextHolder.getContext().getAuthentication();
}
}
/** /**
* Find the original <code>Authentication</code> object from the current user's * Find the original <code>Authentication</code> object from the current user's
* granted authorities. A successfully switched user should have a * granted authorities. A successfully switched user should have a
@ -349,10 +316,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/ */
private Authentication getSourceAuthentication(Authentication current) { private Authentication getSourceAuthentication(Authentication current) {
Authentication original = null; Authentication original = null;
// iterate over granted authorities and find the 'switch user' authority // iterate over granted authorities and find the 'switch user' authority
Collection<? extends GrantedAuthority> authorities = current.getAuthorities(); Collection<? extends GrantedAuthority> authorities = current.getAuthorities();
for (GrantedAuthority auth : authorities) { for (GrantedAuthority auth : authorities) {
// check for switch user type of authority // check for switch user type of authority
if (auth instanceof SwitchUserGrantedAuthority) { if (auth instanceof SwitchUserGrantedAuthority) {
@ -360,7 +325,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
this.logger.debug("Found original switch user granted authority [" + original + "]"); this.logger.debug("Found original switch user granted authority [" + original + "]");
} }
} }
return original; return original;
} }

View File

@ -112,24 +112,28 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
this.logoutSuccessUrl = DEFAULT_LOGIN_PAGE_URL + "?logout"; this.logoutSuccessUrl = DEFAULT_LOGIN_PAGE_URL + "?logout";
this.failureUrl = DEFAULT_LOGIN_PAGE_URL + "?" + ERROR_PARAMETER_NAME; this.failureUrl = DEFAULT_LOGIN_PAGE_URL + "?" + ERROR_PARAMETER_NAME;
if (authFilter != null) { if (authFilter != null) {
this.formLoginEnabled = true; initAuthFilter(authFilter);
this.usernameParameter = authFilter.getUsernameParameter();
this.passwordParameter = authFilter.getPasswordParameter();
if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices())
.getParameter();
}
} }
if (openIDFilter != null) { if (openIDFilter != null) {
this.openIdEnabled = true; initOpenIdFilter(openIDFilter);
this.openIDusernameParameter = "openid_identifier"; }
}
if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) { private void initAuthFilter(UsernamePasswordAuthenticationFilter authFilter) {
this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices()) this.formLoginEnabled = true;
.getParameter(); this.usernameParameter = authFilter.getUsernameParameter();
} this.passwordParameter = authFilter.getPasswordParameter();
if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices()).getParameter();
}
}
private void initOpenIdFilter(AbstractAuthenticationProcessingFilter openIDFilter) {
this.openIdEnabled = true;
this.openIDusernameParameter = "openid_identifier";
if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices())
.getParameter();
} }
} }
@ -214,11 +218,13 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
boolean loginError = isErrorPage(request); boolean loginError = isErrorPage(request);
boolean logoutSuccess = isLogoutSuccess(request); boolean logoutSuccess = isLogoutSuccess(request);
if (isLoginUrlRequest(request) || loginError || logoutSuccess) { if (isLoginUrlRequest(request) || loginError || logoutSuccess) {
@ -226,66 +232,69 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
response.setContentType("text/html;charset=UTF-8"); response.setContentType("text/html;charset=UTF-8");
response.setContentLength(loginPageHtml.getBytes(StandardCharsets.UTF_8).length); response.setContentLength(loginPageHtml.getBytes(StandardCharsets.UTF_8).length);
response.getWriter().write(loginPageHtml); response.getWriter().write(loginPageHtml);
return; return;
} }
chain.doFilter(request, response); chain.doFilter(request, response);
} }
private String generateLoginPageHtml(HttpServletRequest request, boolean loginError, boolean logoutSuccess) { private String generateLoginPageHtml(HttpServletRequest request, boolean loginError, boolean logoutSuccess) {
String errorMsg = "Invalid credentials"; String errorMsg = "Invalid credentials";
if (loginError) { if (loginError) {
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session != null) { if (session != null) {
AuthenticationException ex = (AuthenticationException) session AuthenticationException ex = (AuthenticationException) session
.getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION); .getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
errorMsg = (ex != null) ? ex.getMessage() : "Invalid credentials"; errorMsg = (ex != null) ? ex.getMessage() : "Invalid credentials";
} }
} }
StringBuilder sb = new StringBuilder();
sb.append("<!DOCTYPE html>\n" + "<html lang=\"en\">\n" + " <head>\n" + " <meta charset=\"utf-8\">\n"
+ " <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n"
+ " <meta name=\"description\" content=\"\">\n" + " <meta name=\"author\" content=\"\">\n"
+ " <title>Please sign in</title>\n"
+ " <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n"
+ " <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n"
+ " </head>\n" + " <body>\n" + " <div class=\"container\">\n");
String contextPath = request.getContextPath(); String contextPath = request.getContextPath();
StringBuilder sb = new StringBuilder();
sb.append("<!DOCTYPE html>\n");
sb.append("<html lang=\"en\">\n");
sb.append(" <head>\n");
sb.append(" <meta charset=\"utf-8\">\n");
sb.append(" <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n");
sb.append(" <meta name=\"description\" content=\"\">\n");
sb.append(" <meta name=\"author\" content=\"\">\n");
sb.append(" <title>Please sign in</title>\n");
sb.append(" <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" "
+ "rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n");
sb.append(" <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" "
+ "rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n");
sb.append(" </head>\n");
sb.append(" <body>\n");
sb.append(" <div class=\"container\">\n");
if (this.formLoginEnabled) { if (this.formLoginEnabled) {
sb.append(" <form class=\"form-signin\" method=\"post\" action=\"" + contextPath sb.append(" <form class=\"form-signin\" method=\"post\" action=\"" + contextPath
+ this.authenticationUrl + "\">\n" + this.authenticationUrl + "\">\n");
+ " <h2 class=\"form-signin-heading\">Please sign in</h2>\n" sb.append(" <h2 class=\"form-signin-heading\">Please sign in</h2>\n");
+ createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n" sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n");
+ " <label for=\"username\" class=\"sr-only\">Username</label>\n" sb.append(" <label for=\"username\" class=\"sr-only\">Username</label>\n");
+ " <input type=\"text\" id=\"username\" name=\"" + this.usernameParameter sb.append(" <input type=\"text\" id=\"username\" name=\"" + this.usernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n" + " </p>\n" + "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n");
+ " <p>\n" + " <label for=\"password\" class=\"sr-only\">Password</label>\n" sb.append(" </p>\n");
+ " <input type=\"password\" id=\"password\" name=\"" + this.passwordParameter sb.append(" <p>\n");
+ "\" class=\"form-control\" placeholder=\"Password\" required>\n" + " </p>\n" sb.append(" <label for=\"password\" class=\"sr-only\">Password</label>\n");
+ createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request) sb.append(" <input type=\"password\" id=\"password\" name=\"" + this.passwordParameter
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n" + "\" class=\"form-control\" placeholder=\"Password\" required>\n");
+ " </form>\n"); sb.append(" </p>\n");
sb.append(createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request));
sb.append(" <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n");
sb.append(" </form>\n");
} }
if (this.openIdEnabled) { if (this.openIdEnabled) {
sb.append(" <form name=\"oidf\" class=\"form-signin\" method=\"post\" action=\"" + contextPath sb.append(" <form name=\"oidf\" class=\"form-signin\" method=\"post\" action=\"" + contextPath
+ this.openIDauthenticationUrl + "\">\n" + this.openIDauthenticationUrl + "\">\n");
+ " <h2 class=\"form-signin-heading\">Login with OpenID Identity</h2>\n" sb.append(" <h2 class=\"form-signin-heading\">Login with OpenID Identity</h2>\n");
+ createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n" sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n");
+ " <label for=\"username\" class=\"sr-only\">Identity</label>\n" sb.append(" <label for=\"username\" class=\"sr-only\">Identity</label>\n");
+ " <input type=\"text\" id=\"username\" name=\"" + this.openIDusernameParameter sb.append(" <input type=\"text\" id=\"username\" name=\"" + this.openIDusernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n" + " </p>\n" + "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n");
+ createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request) sb.append(" </p>\n");
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n" sb.append(createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request));
+ " </form>\n"); sb.append(" <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n");
sb.append(" </form>\n");
} }
if (this.oauth2LoginEnabled) { if (this.oauth2LoginEnabled) {
sb.append("<h2 class=\"form-signin-heading\">Login with OAuth 2.0</h2>"); sb.append("<h2 class=\"form-signin-heading\">Login with OAuth 2.0</h2>");
sb.append(createError(loginError, errorMsg)); sb.append(createError(loginError, errorMsg));
@ -303,7 +312,6 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
sb.append("</table>\n"); sb.append("</table>\n");
} }
if (this.saml2LoginEnabled) { if (this.saml2LoginEnabled) {
sb.append("<h2 class=\"form-signin-heading\">Login with SAML 2.0</h2>"); sb.append("<h2 class=\"form-signin-heading\">Login with SAML 2.0</h2>");
sb.append(createError(loginError, errorMsg)); sb.append(createError(loginError, errorMsg));
@ -323,15 +331,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
sb.append("</div>\n"); sb.append("</div>\n");
sb.append("</body></html>"); sb.append("</body></html>");
return sb.toString(); return sb.toString();
} }
private String renderHiddenInputs(HttpServletRequest request) { private String renderHiddenInputs(HttpServletRequest request) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) { for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) {
sb.append("<input name=\"").append(input.getKey()).append("\" type=\"hidden\" value=\"") sb.append("<input name=\"");
.append(input.getValue()).append("\" />\n"); sb.append(input.getKey());
sb.append("\" type=\"hidden\" value=\"");
sb.append(input.getValue());
sb.append("\" />\n");
} }
return sb.toString(); return sb.toString();
} }
@ -356,13 +366,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
private static String createError(boolean isError, String message) { private static String createError(boolean isError, String message) {
return isError ? "<div class=\"alert alert-danger\" role=\"alert\">" + HtmlUtils.htmlEscape(message) + "</div>" if (!isError) {
: ""; return "";
}
return "<div class=\"alert alert-danger\" role=\"alert\">" + HtmlUtils.htmlEscape(message) + "</div>";
} }
private static String createLogoutSuccess(boolean isLogoutSuccess) { private static String createLogoutSuccess(boolean isLogoutSuccess) {
return isLogoutSuccess ? "<div class=\"alert alert-success\" role=\"alert\">You have been signed out</div>" if (!isLogoutSuccess) {
: ""; return "";
}
return "<div class=\"alert alert-success\" role=\"alert\">You have been signed out</div>";
} }
private boolean matches(HttpServletRequest request, String url) { private boolean matches(HttpServletRequest request, String url) {
@ -371,20 +385,16 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
} }
String uri = request.getRequestURI(); String uri = request.getRequestURI();
int pathParamIndex = uri.indexOf(';'); int pathParamIndex = uri.indexOf(';');
if (pathParamIndex > 0) { if (pathParamIndex > 0) {
// strip everything after the first semi-colon // strip everything after the first semi-colon
uri = uri.substring(0, pathParamIndex); uri = uri.substring(0, pathParamIndex);
} }
if (request.getQueryString() != null) { if (request.getQueryString() != null) {
uri += "?" + request.getQueryString(); uri += "?" + request.getQueryString();
} }
if ("".equals(request.getContextPath())) { if ("".equals(request.getContextPath())) {
return uri.equals(url); return uri.equals(url);
} }
return uri.equals(request.getContextPath() + url); return uri.equals(request.getContextPath() + url);
} }

View File

@ -55,21 +55,34 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter {
} }
private void renderLogout(HttpServletRequest request, HttpServletResponse response) throws IOException { private void renderLogout(HttpServletRequest request, HttpServletResponse response) throws IOException {
String page = "<!DOCTYPE html>\n" + "<html lang=\"en\">\n" + " <head>\n" + " <meta charset=\"utf-8\">\n" StringBuilder sb = new StringBuilder();
+ " <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n" sb.append("<!DOCTYPE html>\n");
+ " <meta name=\"description\" content=\"\">\n" + " <meta name=\"author\" content=\"\">\n" sb.append("<html lang=\"en\">\n");
+ " <title>Confirm Log Out?</title>\n" sb.append(" <head>\n");
+ " <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n" sb.append(" <meta charset=\"utf-8\">\n");
+ " <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n" sb.append(" <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n");
+ " </head>\n" + " <body>\n" + " <div class=\"container\">\n" sb.append(" <meta name=\"description\" content=\"\">\n");
+ " <form class=\"form-signin\" method=\"post\" action=\"" + request.getContextPath() sb.append(" <meta name=\"author\" content=\"\">\n");
+ "/logout\">\n" + " <h2 class=\"form-signin-heading\">Are you sure you want to log out?</h2>\n" sb.append(" <title>Confirm Log Out?</title>\n");
+ renderHiddenInputs(request) sb.append(" <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" "
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Log Out</button>\n" + "rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" "
+ " </form>\n" + " </div>\n" + " </body>\n" + "</html>"; + "crossorigin=\"anonymous\">\n");
sb.append(" <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" "
+ "rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n");
sb.append(" </head>\n");
sb.append(" <body>\n");
sb.append(" <div class=\"container\">\n");
sb.append(" <form class=\"form-signin\" method=\"post\" action=\"" + request.getContextPath()
+ "/logout\">\n");
sb.append(" <h2 class=\"form-signin-heading\">Are you sure you want to log out?</h2>\n");
sb.append(renderHiddenInputs(request)
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Log Out</button>\n");
sb.append(" </form>\n");
sb.append(" </div>\n");
sb.append(" </body>\n");
sb.append("</html>");
response.setContentType("text/html;charset=UTF-8"); response.setContentType("text/html;charset=UTF-8");
response.getWriter().write(page); response.getWriter().write(sb.toString());
} }
/** /**
@ -86,8 +99,11 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter {
private String renderHiddenInputs(HttpServletRequest request) { private String renderHiddenInputs(HttpServletRequest request) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) { for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) {
sb.append("<input name=\"").append(input.getKey()).append("\" type=\"hidden\" value=\"") sb.append("<input name=\"");
.append(input.getValue()).append("\" />\n"); sb.append(input.getKey());
sb.append("\" type=\"hidden\" value=\"");
sb.append(input.getValue());
sb.append("\" />\n");
} }
return sb.toString(); return sb.toString();
} }

View File

@ -80,29 +80,17 @@ public class BasicAuthenticationConverter implements AuthenticationConverter {
if (header == null) { if (header == null) {
return null; return null;
} }
header = header.trim(); header = header.trim();
if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) { if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) {
return null; return null;
} }
if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) { if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) {
throw new BadCredentialsException("Empty basic authentication token"); throw new BadCredentialsException("Empty basic authentication token");
} }
byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8); byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8);
byte[] decoded; byte[] decoded = decode(base64Token);
try {
decoded = Base64.getDecoder().decode(base64Token);
}
catch (IllegalArgumentException ex) {
throw new BadCredentialsException("Failed to decode basic authentication token");
}
String token = new String(decoded, getCredentialsCharset(request)); String token = new String(decoded, getCredentialsCharset(request));
int delim = token.indexOf(":"); int delim = token.indexOf(":");
if (delim == -1) { if (delim == -1) {
throw new BadCredentialsException("Invalid basic authentication token"); throw new BadCredentialsException("Invalid basic authentication token");
} }
@ -112,6 +100,15 @@ public class BasicAuthenticationConverter implements AuthenticationConverter {
return result; return result;
} }
private byte[] decode(byte[] base64Token) {
try {
return Base64.getDecoder().decode(base64Token);
}
catch (IllegalArgumentException ex) {
throw new BadCredentialsException("Failed to decode basic authentication token");
}
}
protected Charset getCredentialsCharset(HttpServletRequest request) { protected Charset getCredentialsCharset(HttpServletRequest request) {
return getCredentialsCharset(); return getCredentialsCharset();
} }

View File

@ -24,6 +24,7 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
@ -132,7 +133,6 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
Assert.notNull(this.authenticationManager, "An AuthenticationManager is required"); Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
if (!isIgnoreFailure()) { if (!isIgnoreFailure()) {
Assert.notNull(this.authenticationEntryPoint, "An AuthenticationEntryPoint is required"); Assert.notNull(this.authenticationEntryPoint, "An AuthenticationEntryPoint is required");
} }
@ -141,53 +141,34 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
final boolean debug = this.logger.isDebugEnabled();
try { try {
UsernamePasswordAuthenticationToken authRequest = this.authenticationConverter.convert(request); UsernamePasswordAuthenticationToken authRequest = this.authenticationConverter.convert(request);
if (authRequest == null) { if (authRequest == null) {
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
String username = authRequest.getName(); String username = authRequest.getName();
this.logger.debug(
if (debug) { LogMessage.format("Basic Authentication Authorization header found for user '%s'", username));
this.logger.debug("Basic Authentication Authorization header found for user '" + username + "'");
}
if (authenticationIsRequired(username)) { if (authenticationIsRequired(username)) {
Authentication authResult = this.authenticationManager.authenticate(authRequest); Authentication authResult = this.authenticationManager.authenticate(authRequest);
this.logger.debug(LogMessage.format("Authentication success: %s", authResult));
if (debug) {
this.logger.debug("Authentication success: " + authResult);
}
SecurityContextHolder.getContext().setAuthentication(authResult); SecurityContextHolder.getContext().setAuthentication(authResult);
this.rememberMeServices.loginSuccess(request, response, authResult); this.rememberMeServices.loginSuccess(request, response, authResult);
onSuccessfulAuthentication(request, response, authResult); onSuccessfulAuthentication(request, response, authResult);
} }
} }
catch (AuthenticationException failed) { catch (AuthenticationException ex) {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
this.logger.debug("Authentication request for failed!", ex);
if (debug) {
this.logger.debug("Authentication request for failed!", failed);
}
this.rememberMeServices.loginFail(request, response); this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, ex);
onUnsuccessfulAuthentication(request, response, failed);
if (this.ignoreFailure) { if (this.ignoreFailure) {
chain.doFilter(request, response); chain.doFilter(request, response);
} }
else { else {
this.authenticationEntryPoint.commence(request, response, failed); this.authenticationEntryPoint.commence(request, response, ex);
} }
return; return;
} }
@ -196,40 +177,26 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
private boolean authenticationIsRequired(String username) { private boolean authenticationIsRequired(String username) {
// Only reauthenticate if username doesn't match SecurityContextHolder and user // Only reauthenticate if username doesn't match SecurityContextHolder and user
// isn't authenticated // isn't authenticated (see SEC-53)
// (see SEC-53)
Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication(); Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication();
if (existingAuth == null || !existingAuth.isAuthenticated()) { if (existingAuth == null || !existingAuth.isAuthenticated()) {
return true; return true;
} }
// Limit username comparison to providers which use usernames (ie // Limit username comparison to providers which use usernames (ie
// UsernamePasswordAuthenticationToken) // UsernamePasswordAuthenticationToken) (see SEC-348)
// (see SEC-348)
if (existingAuth instanceof UsernamePasswordAuthenticationToken && !existingAuth.getName().equals(username)) { if (existingAuth instanceof UsernamePasswordAuthenticationToken && !existingAuth.getName().equals(username)) {
return true; return true;
} }
// Handle unusual condition where an AnonymousAuthenticationToken is already // Handle unusual condition where an AnonymousAuthenticationToken is already
// present // present. This shouldn't happen very often, as BasicProcessingFitler is meant to
// This shouldn't happen very often, as BasicProcessingFitler is meant to be // be earlier in the filter chain than AnonymousAuthenticationFilter.
// earlier in the filter // Nevertheless, presence of both an AnonymousAuthenticationToken together with a
// chain than AnonymousAuthenticationFilter. Nevertheless, presence of both an // BASIC authentication request header should indicate reauthentication using the
// AnonymousAuthenticationToken
// together with a BASIC authentication request header should indicate
// reauthentication using the
// BASIC protocol is desirable. This behaviour is also consistent with that // BASIC protocol is desirable. This behaviour is also consistent with that
// provided by form and digest, // provided by form and digest, both of which force re-authentication if the
// both of which force re-authentication if the respective header is detected (and // respective header is detected (and in doing so replace/ any existing
// in doing so replace // AnonymousAuthenticationToken). See SEC-610.
// any existing AnonymousAuthenticationToken). See SEC-610. return (existingAuth instanceof AnonymousAuthenticationToken);
if (existingAuth instanceof AnonymousAuthenticationToken) {
return true;
}
return false;
} }
protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,

View File

@ -44,18 +44,14 @@ final class DigestAuthUtils {
if (str == null) { if (str == null) {
return null; return null;
} }
int len = str.length(); int len = str.length();
if (len == 0) { if (len == 0) {
return EMPTY_STRING_ARRAY; return EMPTY_STRING_ARRAY;
} }
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
int i = 0; int i = 0;
int start = 0; int start = 0;
boolean match = false; boolean match = false;
while (i < len) { while (i < len) {
if (str.charAt(i) == '"') { if (str.charAt(i) == '"') {
i++; i++;
@ -83,7 +79,6 @@ final class DigestAuthUtils {
if (match) { if (match) {
list.add(str.substring(start, i)); list.add(str.substring(start, i));
} }
return list.toArray(new String[0]); return list.toArray(new String[0]);
} }
@ -108,32 +103,19 @@ final class DigestAuthUtils {
static String generateDigest(boolean passwordAlreadyEncoded, String username, String realm, String password, static String generateDigest(boolean passwordAlreadyEncoded, String username, String realm, String password,
String httpMethod, String uri, String qop, String nonce, String nc, String cnonce) String httpMethod, String uri, String qop, String nonce, String nc, String cnonce)
throws IllegalArgumentException { throws IllegalArgumentException {
String a1Md5;
String a2 = httpMethod + ":" + uri; String a2 = httpMethod + ":" + uri;
String a1Md5 = (!passwordAlreadyEncoded) ? DigestAuthUtils.encodePasswordInA1Format(username, realm, password)
: password;
String a2Md5 = md5Hex(a2); String a2Md5 = md5Hex(a2);
if (passwordAlreadyEncoded) {
a1Md5 = password;
}
else {
a1Md5 = DigestAuthUtils.encodePasswordInA1Format(username, realm, password);
}
String digest;
if (qop == null) { if (qop == null) {
// as per RFC 2069 compliant clients (also reaffirmed by RFC 2617) // as per RFC 2069 compliant clients (also reaffirmed by RFC 2617)
digest = a1Md5 + ":" + nonce + ":" + a2Md5; return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5);
} }
else if ("auth".equals(qop)) { if ("auth".equals(qop)) {
// As per RFC 2617 compliant clients // As per RFC 2617 compliant clients
digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5; return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5);
} }
else { throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
return md5Hex(digest);
} }
/** /**
@ -157,28 +139,15 @@ final class DigestAuthUtils {
if ((array == null) || (array.length == 0)) { if ((array == null) || (array.length == 0)) {
return null; return null;
} }
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
for (String s : array) { for (String s : array) {
String postRemove; String postRemove = (removeCharacters != null) ? StringUtils.replace(s, removeCharacters, "") : s;
if (removeCharacters == null) {
postRemove = s;
}
else {
postRemove = StringUtils.replace(s, removeCharacters, "");
}
String[] splitThisArrayElement = split(postRemove, delimiter); String[] splitThisArrayElement = split(postRemove, delimiter);
if (splitThisArrayElement == null) { if (splitThisArrayElement == null) {
continue; continue;
} }
map.put(splitThisArrayElement[0].trim(), splitThisArrayElement[1].trim()); map.put(splitThisArrayElement[0].trim(), splitThisArrayElement[1].trim());
} }
return map; return map;
} }
@ -196,33 +165,24 @@ final class DigestAuthUtils {
static String[] split(String toSplit, String delimiter) { static String[] split(String toSplit, String delimiter) {
Assert.hasLength(toSplit, "Cannot split a null or empty string"); Assert.hasLength(toSplit, "Cannot split a null or empty string");
Assert.hasLength(delimiter, "Cannot use a null or empty delimiter to split a string"); Assert.hasLength(delimiter, "Cannot use a null or empty delimiter to split a string");
Assert.isTrue(delimiter.length() == 1, "Delimiter can only be one character in length");
if (delimiter.length() != 1) {
throw new IllegalArgumentException("Delimiter can only be one character in length");
}
int offset = toSplit.indexOf(delimiter); int offset = toSplit.indexOf(delimiter);
if (offset < 0) { if (offset < 0) {
return null; return null;
} }
String beforeDelimiter = toSplit.substring(0, offset); String beforeDelimiter = toSplit.substring(0, offset);
String afterDelimiter = toSplit.substring(offset + 1); String afterDelimiter = toSplit.substring(offset + 1);
return new String[] { beforeDelimiter, afterDelimiter }; return new String[] { beforeDelimiter, afterDelimiter };
} }
static String md5Hex(String data) { static String md5Hex(String data) {
MessageDigest digest;
try { try {
digest = MessageDigest.getInstance("MD5"); MessageDigest digest = MessageDigest.getInstance("MD5");
return new String(Hex.encode(digest.digest(data.getBytes())));
} }
catch (NoSuchAlgorithmException ex) { catch (NoSuchAlgorithmException ex) {
throw new IllegalStateException("No MD5 algorithm available!"); throw new IllegalStateException("No MD5 algorithm available!");
} }
return new String(Hex.encode(digest.digest(data.getBytes())));
} }
} }

View File

@ -27,9 +27,11 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.util.Assert;
/** /**
* Used by the <code>SecurityEnforcementFilter</code> to commence authentication via the * Used by the <code>SecurityEnforcementFilter</code> to commence authentication via the
@ -68,44 +70,30 @@ public class DigestAuthenticationEntryPoint implements AuthenticationEntryPoint,
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
if ((this.realmName == null) || "".equals(this.realmName)) { Assert.hasLength(this.realmName, "realmName must be specified");
throw new IllegalArgumentException("realmName must be specified"); Assert.hasLength(this.key, "key must be specified");
}
if ((this.key == null) || "".equals(this.key)) {
throw new IllegalArgumentException("key must be specified");
}
} }
@Override @Override
public void commence(HttpServletRequest request, HttpServletResponse response, public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException { AuthenticationException authException) throws IOException {
HttpServletResponse httpResponse = response; // compute a nonce (do not use remote IP address due to proxy farms) format of
// nonce is: base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
// compute a nonce (do not use remote IP address due to proxy farms)
// format of nonce is:
// base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
long expiryTime = System.currentTimeMillis() + (this.nonceValiditySeconds * 1000); long expiryTime = System.currentTimeMillis() + (this.nonceValiditySeconds * 1000);
String signatureValue = DigestAuthUtils.md5Hex(expiryTime + ":" + this.key); String signatureValue = DigestAuthUtils.md5Hex(expiryTime + ":" + this.key);
String nonceValue = expiryTime + ":" + signatureValue; String nonceValue = expiryTime + ":" + signatureValue;
String nonceValueBase64 = new String(Base64.getEncoder().encode(nonceValue.getBytes())); String nonceValueBase64 = new String(Base64.getEncoder().encode(nonceValue.getBytes()));
// qop is quality of protection, as defined by RFC 2617. We do not use opaque due
// qop is quality of protection, as defined by RFC 2617. // to IE violation of RFC 2617 in not representing opaque on subsequent requests
// we do not use opaque due to IE violation of RFC 2617 in not // in same session.
// representing opaque on subsequent requests in same session.
String authenticateHeader = "Digest realm=\"" + this.realmName + "\", " + "qop=\"auth\", nonce=\"" String authenticateHeader = "Digest realm=\"" + this.realmName + "\", " + "qop=\"auth\", nonce=\""
+ nonceValueBase64 + "\""; + nonceValueBase64 + "\"";
if (authException instanceof NonceExpiredException) { if (authException instanceof NonceExpiredException) {
authenticateHeader = authenticateHeader + ", stale=\"true\""; authenticateHeader = authenticateHeader + ", stale=\"true\"";
} }
logger.debug(LogMessage.format("WWW-Authenticate header sent to user agent: %s", authenticateHeader));
if (logger.isDebugEnabled()) { response.addHeader("WWW-Authenticate", authenticateHeader);
logger.debug("WWW-Authenticate header sent to user agent: " + authenticateHeader); response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
}
httpResponse.addHeader("WWW-Authenticate", authenticateHeader);
httpResponse.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
} }
public String getKey() { public String getKey() {

View File

@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.context.MessageSource; import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware; import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor; import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
@ -112,136 +113,105 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
String header = request.getHeader("Authorization"); String header = request.getHeader("Authorization");
if (header == null || !header.startsWith("Digest ")) { if (header == null || !header.startsWith("Digest ")) {
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
logger.debug(LogMessage.format("Digest Authorization header received from user agent: %s", header));
if (logger.isDebugEnabled()) {
logger.debug("Digest Authorization header received from user agent: " + header);
}
DigestData digestAuth = new DigestData(header); DigestData digestAuth = new DigestData(header);
try { try {
digestAuth.validateAndDecode(this.authenticationEntryPoint.getKey(), digestAuth.validateAndDecode(this.authenticationEntryPoint.getKey(),
this.authenticationEntryPoint.getRealmName()); this.authenticationEntryPoint.getRealmName());
} }
catch (BadCredentialsException ex) { catch (BadCredentialsException ex) {
fail(request, response, ex); fail(request, response, ex);
return; return;
} }
// Lookup password for presented username. N.B. DAO-provided password MUST be
// Lookup password for presented username // clear text - not encoded/salted (unless this instance's passwordAlreadyEncoded
// NB: DAO-provided password MUST be clear text - not encoded/salted // property is 'false')
// (unless this instance's passwordAlreadyEncoded property is 'false')
boolean cacheWasUsed = true; boolean cacheWasUsed = true;
UserDetails user = this.userCache.getUserFromCache(digestAuth.getUsername()); UserDetails user = this.userCache.getUserFromCache(digestAuth.getUsername());
String serverDigestMd5; String serverDigestMd5;
try { try {
if (user == null) { if (user == null) {
cacheWasUsed = false; cacheWasUsed = false;
user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername()); user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername());
if (user == null) { if (user == null) {
throw new AuthenticationServiceException( throw new AuthenticationServiceException(
"AuthenticationDao returned null, which is an interface contract violation"); "AuthenticationDao returned null, which is an interface contract violation");
} }
this.userCache.putUserInCache(user); this.userCache.putUserInCache(user);
} }
serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod()); serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod());
// If digest is incorrect, try refreshing from backend and recomputing // If digest is incorrect, try refreshing from backend and recomputing
if (!serverDigestMd5.equals(digestAuth.getResponse()) && cacheWasUsed) { if (!serverDigestMd5.equals(digestAuth.getResponse()) && cacheWasUsed) {
if (logger.isDebugEnabled()) { logger.debug("Digest comparison failure; trying to refresh user from DAO in case password had changed");
logger.debug(
"Digest comparison failure; trying to refresh user from DAO in case password had changed");
}
user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername()); user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername());
this.userCache.putUserInCache(user); this.userCache.putUserInCache(user);
serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod()); serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod());
} }
} }
catch (UsernameNotFoundException notFound) { catch (UsernameNotFoundException ex) {
fail(request, response, String message = this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound",
new BadCredentialsException(this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound", new Object[] { digestAuth.getUsername() }, "Username {0} not found");
new Object[] { digestAuth.getUsername() }, "Username {0} not found"))); fail(request, response, new BadCredentialsException(message));
return; return;
} }
// If digest is still incorrect, definitely reject authentication attempt // If digest is still incorrect, definitely reject authentication attempt
if (!serverDigestMd5.equals(digestAuth.getResponse())) { if (!serverDigestMd5.equals(digestAuth.getResponse())) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format(
logger.debug("Expected response: '" + serverDigestMd5 + "' but received: '" + digestAuth.getResponse() "Expected response: '%s' but received: '%s'; is AuthenticationDao returning clear text passwords?",
+ "'; is AuthenticationDao returning clear text passwords?"); serverDigestMd5, digestAuth.getResponse()));
} String message = this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse",
"Incorrect response");
fail(request, response, new BadCredentialsException( fail(request, response, new BadCredentialsException(message));
this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse", "Incorrect response")));
return; return;
} }
// To get this far, the digest must have been valid // To get this far, the digest must have been valid
// Check the nonce has not expired // Check the nonce has not expired
// We do this last so we can direct the user agent its nonce is stale // We do this last so we can direct the user agent its nonce is stale
// but the request was otherwise appearing to be valid // but the request was otherwise appearing to be valid
if (digestAuth.isNonceExpired()) { if (digestAuth.isNonceExpired()) {
fail(request, response, new NonceExpiredException(this.messages String message = this.messages.getMessage("DigestAuthenticationFilter.nonceExpired",
.getMessage("DigestAuthenticationFilter.nonceExpired", "Nonce has expired/timed out"))); "Nonce has expired/timed out");
fail(request, response, new NonceExpiredException(message));
return; return;
} }
logger.debug(LogMessage.format("Authentication success for user: '%s' with response: '%s'",
if (logger.isDebugEnabled()) { digestAuth.getUsername(), digestAuth.getResponse()));
logger.debug("Authentication success for user: '" + digestAuth.getUsername() + "' with response: '"
+ digestAuth.getResponse() + "'");
}
Authentication authentication = createSuccessfulAuthentication(request, user); Authentication authentication = createSuccessfulAuthentication(request, user);
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); SecurityContextHolder.setContext(context);
chain.doFilter(request, response); chain.doFilter(request, response);
} }
private Authentication createSuccessfulAuthentication(HttpServletRequest request, UserDetails user) { private Authentication createSuccessfulAuthentication(HttpServletRequest request, UserDetails user) {
UsernamePasswordAuthenticationToken authRequest; UsernamePasswordAuthenticationToken authRequest = getAuthRequest(user);
if (this.createAuthenticatedToken) {
authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities());
}
else {
authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword());
}
authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
return authRequest; return authRequest;
} }
private UsernamePasswordAuthenticationToken getAuthRequest(UserDetails user) {
if (this.createAuthenticatedToken) {
return new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities());
}
return new UsernamePasswordAuthenticationToken(user, user.getPassword());
}
private void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) private void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed)
throws IOException, ServletException { throws IOException, ServletException {
SecurityContextHolder.getContext().setAuthentication(null); SecurityContextHolder.getContext().setAuthentication(null);
logger.debug(failed);
if (logger.isDebugEnabled()) {
logger.debug(failed);
}
this.authenticationEntryPoint.commence(request, response, failed); this.authenticationEntryPoint.commence(request, response, failed);
} }
@ -326,7 +296,6 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
this.section212response = header.substring(7); this.section212response = header.substring(7);
String[] headerEntries = DigestAuthUtils.splitIgnoringQuotes(this.section212response, ','); String[] headerEntries = DigestAuthUtils.splitIgnoringQuotes(this.section212response, ',');
Map<String, String> headerMap = DigestAuthUtils.splitEachArrayElementAndCreateMap(headerEntries, "=", "\""); Map<String, String> headerMap = DigestAuthUtils.splitEachArrayElementAndCreateMap(headerEntries, "=", "\"");
this.username = headerMap.get("username"); this.username = headerMap.get("username");
this.realm = headerMap.get("realm"); this.realm = headerMap.get("realm");
this.nonce = headerMap.get("nonce"); this.nonce = headerMap.get("nonce");
@ -335,11 +304,9 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
this.qop = headerMap.get("qop"); // RFC 2617 extension this.qop = headerMap.get("qop"); // RFC 2617 extension
this.nc = headerMap.get("nc"); // RFC 2617 extension this.nc = headerMap.get("nc"); // RFC 2617 extension
this.cnonce = headerMap.get("cnonce"); // RFC 2617 extension this.cnonce = headerMap.get("cnonce"); // RFC 2617 extension
logger.debug(
if (logger.isDebugEnabled()) { LogMessage.format("Extracted username: '%s'; realm: '%s'; nonce: '%s'; uri: '%s'; response: '%s'",
logger.debug("Extracted username: '" + this.username + "'; realm: '" + this.realm + "'; nonce: '" this.username, this.realm, this.nonce, this.uri, this.response));
+ this.nonce + "'; uri: '" + this.uri + "'; response: '" + this.response + "'");
}
} }
void validateAndDecode(String entryPointKey, String expectedRealm) throws BadCredentialsException { void validateAndDecode(String entryPointKey, String expectedRealm) throws BadCredentialsException {
@ -353,23 +320,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
// Check all required parameters for an "auth" qop were supplied (ie RFC 2617) // Check all required parameters for an "auth" qop were supplied (ie RFC 2617)
if ("auth".equals(this.qop)) { if ("auth".equals(this.qop)) {
if ((this.nc == null) || (this.cnonce == null)) { if ((this.nc == null) || (this.cnonce == null)) {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("extracted nc: '%s'; cnonce: '%s'", this.nc, this.cnonce));
logger.debug("extracted nc: '" + this.nc + "'; cnonce: '" + this.cnonce + "'");
}
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.missingAuth", new Object[] { this.section212response }, "DigestAuthenticationFilter.missingAuth", new Object[] { this.section212response },
"Missing mandatory digest value; received header {0}")); "Missing mandatory digest value; received header {0}"));
} }
} }
// Check realm name equals what we expected // Check realm name equals what we expected
if (!expectedRealm.equals(this.realm)) { if (!expectedRealm.equals(this.realm)) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.incorrectRealm", new Object[] { this.realm, expectedRealm }, "DigestAuthenticationFilter.incorrectRealm", new Object[] { this.realm, expectedRealm },
"Response realm name '{0}' does not match system realm name of '{1}'")); "Response realm name '{0}' does not match system realm name of '{1}'"));
} }
// Check nonce was Base64 encoded (as sent by DigestAuthenticationEntryPoint) // Check nonce was Base64 encoded (as sent by DigestAuthenticationEntryPoint)
try { try {
Base64.getDecoder().decode(this.nonce.getBytes()); Base64.getDecoder().decode(this.nonce.getBytes());
@ -379,21 +341,16 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
DigestAuthenticationFilter.this.messages.getMessage("DigestAuthenticationFilter.nonceEncoding", DigestAuthenticationFilter.this.messages.getMessage("DigestAuthenticationFilter.nonceEncoding",
new Object[] { this.nonce }, "Nonce is not encoded in Base64; received nonce {0}")); new Object[] { this.nonce }, "Nonce is not encoded in Base64; received nonce {0}"));
} }
// Decode nonce from Base64 format of nonce is: base64(expirationTime + ":" +
// Decode nonce from Base64 // md5Hex(expirationTime + ":" + key))
// format of nonce is:
// base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
String nonceAsPlainText = new String(Base64.getDecoder().decode(this.nonce.getBytes())); String nonceAsPlainText = new String(Base64.getDecoder().decode(this.nonce.getBytes()));
String[] nonceTokens = StringUtils.delimitedListToStringArray(nonceAsPlainText, ":"); String[] nonceTokens = StringUtils.delimitedListToStringArray(nonceAsPlainText, ":");
if (nonceTokens.length != 2) { if (nonceTokens.length != 2) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.nonceNotTwoTokens", new Object[] { nonceAsPlainText }, "DigestAuthenticationFilter.nonceNotTwoTokens", new Object[] { nonceAsPlainText },
"Nonce should have yielded two tokens but was {0}")); "Nonce should have yielded two tokens but was {0}"));
} }
// Extract expiry time from nonce // Extract expiry time from nonce
try { try {
this.nonceExpiryTime = new Long(nonceTokens[0]); this.nonceExpiryTime = new Long(nonceTokens[0]);
} }
@ -402,10 +359,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
"DigestAuthenticationFilter.nonceNotNumeric", new Object[] { nonceAsPlainText }, "DigestAuthenticationFilter.nonceNotNumeric", new Object[] { nonceAsPlainText },
"Nonce token should have yielded a numeric first token, but was {0}")); "Nonce token should have yielded a numeric first token, but was {0}"));
} }
// Check signature of nonce matches this expiry time // Check signature of nonce matches this expiry time
String expectedNonceSignature = DigestAuthUtils.md5Hex(this.nonceExpiryTime + ":" + entryPointKey); String expectedNonceSignature = DigestAuthUtils.md5Hex(this.nonceExpiryTime + ":" + entryPointKey);
if (!expectedNonceSignature.equals(nonceTokens[1])) { if (!expectedNonceSignature.equals(nonceTokens[1])) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.nonceCompromised", new Object[] { nonceAsPlainText }, "DigestAuthenticationFilter.nonceCompromised", new Object[] { nonceAsPlainText },
@ -414,9 +369,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
} }
String calculateServerDigest(String password, String httpMethod) { String calculateServerDigest(String password, String httpMethod) {
// Compute the expected response-digest (will be in hex form) // Compute the expected response-digest (will be in hex form). Don't catch
// IllegalArgumentException (already checked validity)
// Don't catch IllegalArgumentException (already checked validity)
return DigestAuthUtils.generateDigest(DigestAuthenticationFilter.this.passwordAlreadyEncoded, this.username, return DigestAuthUtils.generateDigest(DigestAuthenticationFilter.this.passwordAlreadyEncoded, this.username,
this.realm, password, httpMethod, this.uri, this.qop, this.nonce, this.nc, this.cnonce); this.realm, password, httpMethod, this.uri, this.qop, this.nonce, this.nc, this.cnonce);
} }

View File

@ -105,9 +105,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
if (authPrincipal.errorOnInvalidType()) { if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
} }
else { return null;
return null;
}
} }
return principal; return principal;
} }

View File

@ -173,11 +173,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp
*/ */
private void registerFilters(ServletContext servletContext, boolean insertBeforeOtherFilters, Filter... filters) { private void registerFilters(ServletContext servletContext, boolean insertBeforeOtherFilters, Filter... filters) {
Assert.notEmpty(filters, "filters cannot be null or empty"); Assert.notEmpty(filters, "filters cannot be null or empty");
for (Filter filter : filters) { for (Filter filter : filters) {
if (filter == null) { Assert.notNull(filter, () -> "filters cannot contain null values. Got " + Arrays.asList(filters));
throw new IllegalArgumentException("filters cannot contain null values. Got " + Arrays.asList(filters));
}
String filterName = Conventions.getVariableName(filter); String filterName = Conventions.getVariableName(filter);
registerFilter(servletContext, insertBeforeOtherFilters, filterName, filter); registerFilter(servletContext, insertBeforeOtherFilters, filterName, filter);
} }
@ -195,10 +192,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp
private void registerFilter(ServletContext servletContext, boolean insertBeforeOtherFilters, String filterName, private void registerFilter(ServletContext servletContext, boolean insertBeforeOtherFilters, String filterName,
Filter filter) { Filter filter) {
Dynamic registration = servletContext.addFilter(filterName, filter); Dynamic registration = servletContext.addFilter(filterName, filter);
if (registration == null) { Assert.state(registration != null, () -> "Duplicate Filter registration for '" + filterName
throw new IllegalStateException("Duplicate Filter registration for '" + filterName + "'. Check to ensure the Filter is only configured once.");
+ "'. Check to ensure the Filter is only configured once.");
}
registration.setAsyncSupported(isAsyncSecuritySupported()); registration.setAsyncSupported(isAsyncSecuritySupported());
EnumSet<DispatcherType> dispatcherTypes = getSecurityDispatcherTypes(); EnumSet<DispatcherType> dispatcherTypes = getSecurityDispatcherTypes();
registration.addMappingForUrlPatterns(dispatcherTypes, !insertBeforeOtherFilters, "/*"); registration.addMappingForUrlPatterns(dispatcherTypes, !insertBeforeOtherFilters, "/*");

View File

@ -28,6 +28,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -115,24 +116,18 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
HttpServletRequest request = requestResponseHolder.getRequest(); HttpServletRequest request = requestResponseHolder.getRequest();
HttpServletResponse response = requestResponseHolder.getResponse(); HttpServletResponse response = requestResponseHolder.getResponse();
HttpSession httpSession = request.getSession(false); HttpSession httpSession = request.getSession(false);
SecurityContext context = readSecurityContextFromSession(httpSession); SecurityContext context = readSecurityContextFromSession(httpSession);
if (context == null) { if (context == null) {
if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format(
this.logger.debug("No SecurityContext was available from the HttpSession: " + httpSession + ". " "No SecurityContext was available from the HttpSession: %s. A new one will be created.",
+ "A new one will be created."); httpSession));
}
context = generateNewContext(); context = generateNewContext();
} }
SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request, SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request,
httpSession != null, context); httpSession != null, context);
requestResponseHolder.setResponse(wrappedResponse); requestResponseHolder.setResponse(wrappedResponse);
requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse)); requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse));
return context; return context;
} }
@ -140,13 +135,10 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response, SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response,
SaveContextOnUpdateOrErrorResponseWrapper.class); SaveContextOnUpdateOrErrorResponseWrapper.class);
if (responseWrapper == null) { Assert.state(responseWrapper != null, () -> "Cannot invoke saveContext on response " + response
throw new IllegalStateException("Cannot invoke saveContext on response " + response + ". You must use the HttpRequestResponseHolder.response after invoking loadContext");
+ ". You must use the HttpRequestResponseHolder.response after invoking loadContext"); // saveContext() might already be called by the response wrapper if something in
} // the chain called sendError() or sendRedirect(). This ensures we only call it
// saveContext() might already be called by the response wrapper
// if something in the chain called sendError() or sendRedirect(). This ensures we
// only call it
// once per request. // once per request.
if (!responseWrapper.isContextSaved()) { if (!responseWrapper.isContextSaved()) {
responseWrapper.saveContext(context); responseWrapper.saveContext(context);
@ -156,11 +148,9 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
@Override @Override
public boolean containsContext(HttpServletRequest request) { public boolean containsContext(HttpServletRequest request) {
HttpSession session = request.getSession(false); HttpSession session = request.getSession(false);
if (session == null) { if (session == null) {
return false; return false;
} }
return session.getAttribute(this.springSecurityContextKey) != null; return session.getAttribute(this.springSecurityContextKey) != null;
} }
@ -168,47 +158,30 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
* @param httpSession the session obtained from the request. * @param httpSession the session obtained from the request.
*/ */
private SecurityContext readSecurityContextFromSession(HttpSession httpSession) { private SecurityContext readSecurityContextFromSession(HttpSession httpSession) {
final boolean debug = this.logger.isDebugEnabled();
if (httpSession == null) { if (httpSession == null) {
if (debug) { this.logger.debug("No HttpSession currently exists");
this.logger.debug("No HttpSession currently exists");
}
return null; return null;
} }
// Session exists, so try to obtain a context from it. // Session exists, so try to obtain a context from it.
Object contextFromSession = httpSession.getAttribute(this.springSecurityContextKey); Object contextFromSession = httpSession.getAttribute(this.springSecurityContextKey);
if (contextFromSession == null) { if (contextFromSession == null) {
if (debug) { this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT");
this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT");
}
return null; return null;
} }
// We now have the security context object from the session. // We now have the security context object from the session.
if (!(contextFromSession instanceof SecurityContext)) { if (!(contextFromSession instanceof SecurityContext)) {
if (this.logger.isWarnEnabled()) { this.logger.warn(LogMessage.format(
this.logger.warn(this.springSecurityContextKey + " did not contain a SecurityContext but contained: '" "%s did not contain a SecurityContext but contained: '%s'; are you improperly "
+ contextFromSession + "'; are you improperly modifying the HttpSession directly " + "modifying the HttpSession directly (you should always use SecurityContextHolder) "
+ "(you should always use SecurityContextHolder) or using the HttpSession attribute " + "or using the HttpSession attribute reserved for this class?",
+ "reserved for this class?"); this.springSecurityContextKey, contextFromSession));
}
return null; return null;
} }
if (debug) { this.logger.debug(LogMessage.format("Obtained a valid SecurityContext from %s: '%s'",
this.logger.debug("Obtained a valid SecurityContext from " + this.springSecurityContextKey + ": '" this.springSecurityContextKey, contextFromSession));
+ contextFromSession + "'");
}
// Everything OK. The only non-null return from this method. // Everything OK. The only non-null return from this method.
return (SecurityContext) contextFromSession; return (SecurityContext) contextFromSession;
} }
@ -306,6 +279,8 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
*/ */
final class SaveToSessionResponseWrapper extends SaveContextOnUpdateOrErrorResponseWrapper { final class SaveToSessionResponseWrapper extends SaveContextOnUpdateOrErrorResponseWrapper {
private final Log logger = HttpSessionSecurityContextRepository.this.logger;
private final HttpServletRequest request; private final HttpServletRequest request;
private final boolean httpSessionExistedAtStartOfRequest; private final boolean httpSessionExistedAtStartOfRequest;
@ -349,41 +324,29 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
protected void saveContext(SecurityContext context) { protected void saveContext(SecurityContext context) {
final Authentication authentication = context.getAuthentication(); final Authentication authentication = context.getAuthentication();
HttpSession httpSession = this.request.getSession(false); HttpSession httpSession = this.request.getSession(false);
String springSecurityContextKey = HttpSessionSecurityContextRepository.this.springSecurityContextKey;
// See SEC-776 // See SEC-776
if (authentication == null if (authentication == null
|| HttpSessionSecurityContextRepository.this.trustResolver.isAnonymous(authentication)) { || HttpSessionSecurityContextRepository.this.trustResolver.isAnonymous(authentication)) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { this.logger.debug("SecurityContext is empty or contents are anonymous - "
HttpSessionSecurityContextRepository.this.logger.debug( + "context will not be stored in HttpSession.");
"SecurityContext is empty or contents are anonymous - context will not be stored in HttpSession.");
}
if (httpSession != null && this.authBeforeExecution != null) { if (httpSession != null && this.authBeforeExecution != null) {
// SEC-1587 A non-anonymous context may still be in the session // SEC-1587 A non-anonymous context may still be in the session
// SEC-1735 remove if the contextBeforeExecution was not anonymous // SEC-1735 remove if the contextBeforeExecution was not anonymous
httpSession.removeAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey); httpSession.removeAttribute(springSecurityContextKey);
} }
return; return;
} }
httpSession = (httpSession != null) ? httpSession : createNewSessionIfAllowed(context);
if (httpSession == null) {
httpSession = createNewSessionIfAllowed(context);
}
// If HttpSession exists, store current SecurityContext but only if it has // If HttpSession exists, store current SecurityContext but only if it has
// actually changed in this thread (see SEC-37, SEC-1307, SEC-1528) // actually changed in this thread (see SEC-37, SEC-1307, SEC-1528)
if (httpSession != null) { if (httpSession != null) {
// We may have a new session, so check also whether the context attribute // We may have a new session, so check also whether the context attribute
// is set SEC-1561 // is set SEC-1561
if (contextChanged(context) || httpSession if (contextChanged(context) || httpSession.getAttribute(springSecurityContextKey) == null) {
.getAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey) == null) { httpSession.setAttribute(springSecurityContextKey, context);
httpSession.setAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey, this.logger.debug(LogMessage.format("SecurityContext '%s' stored to HttpSession: '%s'", context,
context); httpSession));
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger
.debug("SecurityContext '" + context + "' stored to HttpSession: '" + httpSession);
}
} }
} }
} }
@ -396,56 +359,37 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
if (isTransientAuthentication(context.getAuthentication())) { if (isTransientAuthentication(context.getAuthentication())) {
return null; return null;
} }
if (this.httpSessionExistedAtStartOfRequest) { if (this.httpSessionExistedAtStartOfRequest) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { this.logger.debug("HttpSession is now null, but was not null at start of request; "
HttpSessionSecurityContextRepository.this.logger + "session was invalidated, so do not create a new session");
.debug("HttpSession is now null, but was not null at start of request; "
+ "session was invalidated, so do not create a new session");
}
return null; return null;
} }
if (!HttpSessionSecurityContextRepository.this.allowSessionCreation) { if (!HttpSessionSecurityContextRepository.this.allowSessionCreation) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { this.logger.debug("The HttpSession is currently null, and the "
HttpSessionSecurityContextRepository.this.logger.debug("The HttpSession is currently null, and the " + HttpSessionSecurityContextRepository.class.getSimpleName()
+ HttpSessionSecurityContextRepository.class.getSimpleName() + " is prohibited from creating an HttpSession "
+ " is prohibited from creating an HttpSession " + "(because the allowSessionCreation property is false) - SecurityContext thus not "
+ "(because the allowSessionCreation property is false) - SecurityContext thus not " + "stored for next request");
+ "stored for next request");
}
return null; return null;
} }
// Generate a HttpSession only if we need to // Generate a HttpSession only if we need to
if (HttpSessionSecurityContextRepository.this.contextObject.equals(context)) { if (HttpSessionSecurityContextRepository.this.contextObject.equals(context)) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format(
HttpSessionSecurityContextRepository.this.logger.debug( "HttpSession is null, but SecurityContext has not changed from "
"HttpSession is null, but SecurityContext has not changed from default empty context: ' " + "default empty context: '%s'; not creating HttpSession or storing SecurityContext",
+ context + "'; not creating HttpSession or storing SecurityContext"); context));
}
return null; return null;
} }
this.logger.debug("HttpSession being created as SecurityContext is non-default");
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger
.debug("HttpSession being created as SecurityContext is non-default");
}
try { try {
return this.request.getSession(true); return this.request.getSession(true);
} }
catch (IllegalStateException ex) { catch (IllegalStateException ex) {
// Response must already be committed, therefore can't create a new // Response must already be committed, therefore can't create a new
// session // session
HttpSessionSecurityContextRepository.this.logger this.logger.warn("Failed to create a session, as response has been committed. "
.warn("Failed to create a session, as response has been committed. Unable to store" + "Unable to store SecurityContext.");
+ " SecurityContext.");
} }
return null; return null;
} }

View File

@ -44,7 +44,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommit
private boolean contextSaved = false; private boolean contextSaved = false;
/* See SEC-1052 */ // See SEC-1052
private final boolean disableUrlRewriting; private final boolean disableUrlRewriting;
/** /**

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.GenericFilterBean; import org.springframework.web.filter.GenericFilterBean;
@ -74,49 +75,36 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
} }
@Override @Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req; doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
HttpServletResponse response = (HttpServletResponse) res; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
// ensure that filter is only applied once per request
if (request.getAttribute(FILTER_APPLIED) != null) { if (request.getAttribute(FILTER_APPLIED) != null) {
// ensure that filter is only applied once per request
chain.doFilter(request, response); chain.doFilter(request, response);
return; return;
} }
final boolean debug = this.logger.isDebugEnabled();
request.setAttribute(FILTER_APPLIED, Boolean.TRUE); request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
if (this.forceEagerSessionCreation) { if (this.forceEagerSessionCreation) {
HttpSession session = request.getSession(); HttpSession session = request.getSession();
this.logger.debug(LogMessage.format("Eagerly created session: %s", session.getId()));
if (debug && session.isNew()) {
this.logger.debug("Eagerly created session: " + session.getId());
}
} }
HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder); SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder);
try { try {
SecurityContextHolder.setContext(contextBeforeChainExecution); SecurityContextHolder.setContext(contextBeforeChainExecution);
chain.doFilter(holder.getRequest(), holder.getResponse()); chain.doFilter(holder.getRequest(), holder.getResponse());
} }
finally { finally {
SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext(); SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext();
// Crucial removal of SecurityContextHolder contents - do this before anything // Crucial removal of SecurityContextHolder contents before anything else.
// else.
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse()); this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse());
request.removeAttribute(FILTER_APPLIED); request.removeAttribute(FILTER_APPLIED);
this.logger.debug("SecurityContextHolder now cleared, as request processing completed");
if (debug) {
this.logger.debug("SecurityContextHolder now cleared, as request processing completed");
}
} }
} }

View File

@ -46,14 +46,12 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager
.getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY); .getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY);
if (securityProcessingInterceptor == null) { if (securityProcessingInterceptor == null) {
asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY,
new SecurityContextCallableProcessingInterceptor()); new SecurityContextCallableProcessingInterceptor());
} }
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }

View File

@ -20,6 +20,7 @@ import java.util.Enumeration;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.context.support.WebApplicationContextUtils;
@ -47,11 +48,10 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC
* @see ServletContext#getAttributeNames() * @see ServletContext#getAttributeNames()
*/ */
public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) { public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) {
WebApplicationContext wac = _findWebApplicationContext(servletContext); WebApplicationContext webApplicationContext = compatiblyFindWebApplicationContext(servletContext);
if (wac == null) { Assert.state(webApplicationContext != null,
throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?"); "No WebApplicationContext found: no ContextLoaderListener registered?");
} return webApplicationContext;
return wac;
} }
/** /**
@ -59,23 +59,21 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC
* spring framework 4.1.x. * spring framework 4.1.x.
* @see #findWebApplicationContext(ServletContext) * @see #findWebApplicationContext(ServletContext)
*/ */
private static WebApplicationContext _findWebApplicationContext(ServletContext sc) { private static WebApplicationContext compatiblyFindWebApplicationContext(ServletContext sc) {
WebApplicationContext wac = getWebApplicationContext(sc); WebApplicationContext webApplicationContext = getWebApplicationContext(sc);
if (wac == null) { if (webApplicationContext == null) {
Enumeration<String> attrNames = sc.getAttributeNames(); Enumeration<String> attrNames = sc.getAttributeNames();
while (attrNames.hasMoreElements()) { while (attrNames.hasMoreElements()) {
String attrName = attrNames.nextElement(); String attrName = attrNames.nextElement();
Object attrValue = sc.getAttribute(attrName); Object attrValue = sc.getAttribute(attrName);
if (attrValue instanceof WebApplicationContext) { if (attrValue instanceof WebApplicationContext) {
if (wac != null) { Assert.state(webApplicationContext == null, "No unique WebApplicationContext found: more than one "
throw new IllegalStateException("No unique WebApplicationContext found: more than one " + "DispatcherServlet registered with publishContext=true?");
+ "DispatcherServlet registered with publishContext=true?"); webApplicationContext = (WebApplicationContext) attrValue;
}
wac = (WebApplicationContext) attrValue;
} }
} }
} }
return wac; return webApplicationContext;
} }
} }

View File

@ -69,30 +69,13 @@ public final class CookieCsrfTokenRepository implements CsrfTokenRepository {
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
String tokenValue = (token != null) ? token.getToken() : ""; String tokenValue = (token != null) ? token.getToken() : "";
Cookie cookie = new Cookie(this.cookieName, tokenValue); Cookie cookie = new Cookie(this.cookieName, tokenValue);
if (this.secure == null) { cookie.setSecure((this.secure != null) ? this.secure : request.isSecure());
cookie.setSecure(request.isSecure()); cookie.setPath(StringUtils.hasLength(this.cookiePath) ? this.cookiePath : this.getRequestContext(request));
} cookie.setMaxAge((token != null) ? -1 : 0);
else {
cookie.setSecure(this.secure);
}
if (this.cookiePath != null && !this.cookiePath.isEmpty()) {
cookie.setPath(this.cookiePath);
}
else {
cookie.setPath(this.getRequestContext(request));
}
if (token == null) {
cookie.setMaxAge(0);
}
else {
cookie.setMaxAge(-1);
}
cookie.setHttpOnly(this.cookieHttpOnly); cookie.setHttpOnly(this.cookieHttpOnly);
if (this.cookieDomain != null && !this.cookieDomain.isEmpty()) { if (StringUtils.hasLength(this.cookieDomain)) {
cookie.setDomain(this.cookieDomain); cookie.setDomain(this.cookieDomain);
} }
response.addCookie(cookie); response.addCookie(cookie);
} }

View File

@ -51,10 +51,8 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; boolean containsToken = this.csrfTokenRepository.loadToken(request) != null;
if (containsToken) { if (containsToken) {
this.csrfTokenRepository.saveToken(null, request, response); this.csrfTokenRepository.saveToken(null, request, response);
CsrfToken newToken = this.csrfTokenRepository.generateToken(request); CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
this.csrfTokenRepository.saveToken(newToken, request, response); this.csrfTokenRepository.saveToken(newToken, request, response);
request.setAttribute(CsrfToken.class.getName(), newToken); request.setAttribute(CsrfToken.class.getName(), newToken);
request.setAttribute(newToken.getParameterName(), newToken); request.setAttribute(newToken.getParameterName(), newToken);
} }

View File

@ -29,6 +29,8 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.UrlUtils;
@ -97,39 +99,30 @@ public final class CsrfFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
request.setAttribute(HttpServletResponse.class.getName(), response); request.setAttribute(HttpServletResponse.class.getName(), response);
CsrfToken csrfToken = this.tokenRepository.loadToken(request); CsrfToken csrfToken = this.tokenRepository.loadToken(request);
final boolean missingToken = csrfToken == null; boolean missingToken = (csrfToken == null);
if (missingToken) { if (missingToken) {
csrfToken = this.tokenRepository.generateToken(request); csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response); this.tokenRepository.saveToken(csrfToken, request, response);
} }
request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken);
if (!this.requireCsrfProtectionMatcher.matches(request)) { if (!this.requireCsrfProtectionMatcher.matches(request)) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
String actualToken = request.getHeader(csrfToken.getHeaderName()); String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) { if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName()); actualToken = request.getParameter(csrfToken.getParameterName());
} }
if (!csrfToken.getToken().equals(actualToken)) { if (!csrfToken.getToken().equals(actualToken)) {
if (this.logger.isDebugEnabled()) { this.logger.debug(
this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)); LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
} AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
if (missingToken) { : new MissingCsrfTokenException(actualToken);
this.accessDeniedHandler.handle(request, response, new MissingCsrfTokenException(actualToken)); this.accessDeniedHandler.handle(request, response, exception);
}
else {
this.accessDeniedHandler.handle(request, response,
new InvalidCsrfTokenException(csrfToken, actualToken));
}
return; return;
} }
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }

View File

@ -24,7 +24,6 @@ import java.io.Serializable;
* @author Rob Winch * @author Rob Winch
* @since 3.2 * @since 3.2
* @see DefaultCsrfToken * @see DefaultCsrfToken
*
*/ */
public interface CsrfToken extends Serializable { public interface CsrfToken extends Serializable {

View File

@ -87,11 +87,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
private HttpServletResponse getResponse(HttpServletRequest request) { private HttpServletResponse getResponse(HttpServletRequest request) {
HttpServletResponse response = (HttpServletResponse) request.getAttribute(HTTP_RESPONSE_ATTR); HttpServletResponse response = (HttpServletResponse) request.getAttribute(HTTP_RESPONSE_ATTR);
if (response == null) { Assert.notNull(response, () -> "The HttpServletRequest attribute must contain an HttpServletResponse "
throw new IllegalArgumentException( + "for the attribute " + HTTP_RESPONSE_ATTR);
"The HttpServletRequest attribute must contain an HttpServletResponse for the attribute "
+ HTTP_RESPONSE_ATTR);
}
return response; return response;
} }
@ -166,7 +163,6 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
if (this.tokenRepository == null) { if (this.tokenRepository == null) {
return; return;
} }
synchronized (this) { synchronized (this) {
if (this.tokenRepository != null) { if (this.tokenRepository != null) {
this.tokenRepository.saveToken(this.delegate, this.request, this.response); this.tokenRepository.saveToken(this.delegate, this.request, this.response);

View File

@ -50,35 +50,35 @@ public final class DebugFilter implements Filter {
static final String ALREADY_FILTERED_ATTR_NAME = DebugFilter.class.getName().concat(".FILTERED"); static final String ALREADY_FILTERED_ATTR_NAME = DebugFilter.class.getName().concat(".FILTERED");
private final FilterChainProxy fcp; private final FilterChainProxy filterChainProxy;
private final Logger logger = new Logger(); private final Logger logger = new Logger();
public DebugFilter(FilterChainProxy fcp) { public DebugFilter(FilterChainProxy filterChainProxy) {
this.fcp = fcp; this.filterChainProxy = filterChainProxy;
} }
@Override @Override
public void doFilter(ServletRequest srvltRequest, ServletResponse srvltResponse, FilterChain filterChain) public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
if (!(srvltRequest instanceof HttpServletRequest) || !(srvltResponse instanceof HttpServletResponse)) {
throw new ServletException("DebugFilter just supports HTTP requests"); throw new ServletException("DebugFilter just supports HTTP requests");
} }
HttpServletRequest request = (HttpServletRequest) srvltRequest; doFilter((HttpServletRequest) request, (HttpServletResponse) response, filterChain);
HttpServletResponse response = (HttpServletResponse) srvltResponse; }
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
List<Filter> filters = getFilters(request); List<Filter> filters = getFilters(request);
this.logger.info("Request received for " + request.getMethod() + " '" + UrlUtils.buildRequestUrl(request) this.logger.info("Request received for " + request.getMethod() + " '" + UrlUtils.buildRequestUrl(request)
+ "':\n\n" + request + "\n\n" + "servletPath:" + request.getServletPath() + "\n" + "pathInfo:" + "':\n\n" + request + "\n\n" + "servletPath:" + request.getServletPath() + "\n" + "pathInfo:"
+ request.getPathInfo() + "\n" + "headers: \n" + formatHeaders(request) + "\n\n" + request.getPathInfo() + "\n" + "headers: \n" + formatHeaders(request) + "\n\n"
+ formatFilters(filters)); + formatFilters(filters));
if (request.getAttribute(ALREADY_FILTERED_ATTR_NAME) == null) { if (request.getAttribute(ALREADY_FILTERED_ATTR_NAME) == null) {
invokeWithWrappedRequest(request, response, filterChain); invokeWithWrappedRequest(request, response, filterChain);
} }
else { else {
this.fcp.doFilter(request, response, filterChain); this.filterChainProxy.doFilter(request, response, filterChain);
} }
} }
@ -87,7 +87,7 @@ public final class DebugFilter implements Filter {
request.setAttribute(ALREADY_FILTERED_ATTR_NAME, Boolean.TRUE); request.setAttribute(ALREADY_FILTERED_ATTR_NAME, Boolean.TRUE);
request = new DebugRequestWrapper(request); request = new DebugRequestWrapper(request);
try { try {
this.fcp.doFilter(request, response, filterChain); this.filterChainProxy.doFilter(request, response, filterChain);
} }
finally { finally {
request.removeAttribute(ALREADY_FILTERED_ATTR_NAME); request.removeAttribute(ALREADY_FILTERED_ATTR_NAME);
@ -134,7 +134,7 @@ public final class DebugFilter implements Filter {
} }
private List<Filter> getFilters(HttpServletRequest request) { private List<Filter> getFilters(HttpServletRequest request) {
for (SecurityFilterChain chain : this.fcp.getFilterChains()) { for (SecurityFilterChain chain : this.filterChainProxy.getFilterChains()) {
if (chain.matches(request)) { if (chain.matches(request)) {
return chain.getFilters(); return chain.getFilters();
} }
@ -163,11 +163,9 @@ public final class DebugFilter implements Filter {
public HttpSession getSession() { public HttpSession getSession() {
boolean sessionExists = super.getSession(false) != null; boolean sessionExists = super.getSession(false) != null;
HttpSession session = super.getSession(); HttpSession session = super.getSession();
if (!sessionExists) { if (!sessionExists) {
DebugRequestWrapper.logger.info("New HTTP session created: " + session.getId(), true); DebugRequestWrapper.logger.info("New HTTP session created: " + session.getId(), true);
} }
return session; return session;
} }

View File

@ -50,19 +50,17 @@ public class DefaultHttpFirewall implements HttpFirewall {
@Override @Override
public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException { public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
FirewalledRequest fwr = new RequestWrapper(request); FirewalledRequest firewalledRequest = new RequestWrapper(request);
if (!isNormalized(firewalledRequest.getServletPath()) || !isNormalized(firewalledRequest.getPathInfo())) {
if (!isNormalized(fwr.getServletPath()) || !isNormalized(fwr.getPathInfo())) { throw new RequestRejectedException(
throw new RequestRejectedException("Un-normalized paths are not supported: " + fwr.getServletPath() "Un-normalized paths are not supported: " + firewalledRequest.getServletPath()
+ ((fwr.getPathInfo() != null) ? fwr.getPathInfo() : "")); + ((firewalledRequest.getPathInfo() != null) ? firewalledRequest.getPathInfo() : ""));
} }
String requestURI = firewalledRequest.getRequestURI();
String requestURI = fwr.getRequestURI();
if (containsInvalidUrlEncodedSlash(requestURI)) { if (containsInvalidUrlEncodedSlash(requestURI)) {
throw new RequestRejectedException("The requestURI cannot contain encoded slash. Got " + requestURI); throw new RequestRejectedException("The requestURI cannot contain encoded slash. Got " + requestURI);
} }
return firewalledRequest;
return fwr;
} }
@Override @Override
@ -89,11 +87,9 @@ public class DefaultHttpFirewall implements HttpFirewall {
if (this.allowUrlEncodedSlash || uri == null) { if (this.allowUrlEncodedSlash || uri == null) {
return false; return false;
} }
if (uri.contains("%2f") || uri.contains("%2F")) { if (uri.contains("%2f") || uri.contains("%2F")) {
return true; return true;
} }
return false; return false;
} }
@ -107,22 +103,18 @@ public class DefaultHttpFirewall implements HttpFirewall {
if (path == null) { if (path == null) {
return true; return true;
} }
for (int i = path.length(); i > 0;) {
for (int j = path.length(); j > 0;) { int slashIndex = path.lastIndexOf('/', i - 1);
int i = path.lastIndexOf('/', j - 1); int gap = i - slashIndex;
int gap = j - i; if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
if (gap == 2 && path.charAt(i + 1) == '.') {
// ".", "/./" or "/." // ".", "/./" or "/."
return false; return false;
} }
else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') { if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
return false; return false;
} }
i = slashIndex;
j = i;
} }
return true; return true;
} }

View File

@ -22,6 +22,8 @@ import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpServletResponseWrapper;
import org.springframework.util.Assert;
/** /**
* @author Luke Taylor * @author Luke Taylor
* @author Eddú Meléndez * @author Eddú Meléndez
@ -71,9 +73,7 @@ class FirewalledResponse extends HttpServletResponseWrapper {
} }
void validateCrlf(String name, String value) { void validateCrlf(String name, String value) {
if (hasCrlf(name) || hasCrlf(value)) { Assert.isTrue(!hasCrlf(name) && !hasCrlf(value), () -> "Invalid characters (CR/LF) in header " + name);
throw new IllegalArgumentException("Invalid characters (CR/LF) in header " + name);
}
} }
private boolean hasCrlf(String value) { private boolean hasCrlf(String value) {

View File

@ -24,6 +24,8 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
/** /**
* A simple implementation of {@link RequestRejectedHandler} that sends an error with * A simple implementation of {@link RequestRejectedHandler} that sends an error with
* configurable status code. * configurable status code.
@ -55,10 +57,8 @@ public class HttpStatusRequestRejectedHandler implements RequestRejectedHandler
@Override @Override
public void handle(HttpServletRequest request, HttpServletResponse response, public void handle(HttpServletRequest request, HttpServletResponse response,
RequestRejectedException requestRejectedException) throws IOException { RequestRejectedException requestRejectedException) throws IOException {
if (logger.isDebugEnabled()) { logger.debug(LogMessage.format("Rejecting request due to: %s", requestRejectedException.getMessage()),
logger.debug("Rejecting request due to: " + requestRejectedException.getMessage(), requestRejectedException);
requestRejectedException);
}
response.sendError(this.httpError); response.sendError(this.httpError);
} }

View File

@ -74,10 +74,8 @@ final class RequestWrapper extends FirewalledRequest {
if (path == null) { if (path == null) {
return null; return null;
} }
int semicolonIndex = path.indexOf(';');
int scIndex = path.indexOf(';'); if (semicolonIndex < 0) {
if (scIndex < 0) {
int doubleSlashIndex = path.indexOf("//"); int doubleSlashIndex = path.indexOf("//");
if (doubleSlashIndex < 0) { if (doubleSlashIndex < 0) {
// Most likely case, no parameters in any segment and no '//', so no // Most likely case, no parameters in any segment and no '//', so no
@ -85,29 +83,23 @@ final class RequestWrapper extends FirewalledRequest {
return path; return path;
} }
} }
StringTokenizer tokenizer = new StringTokenizer(path, "/");
StringTokenizer st = new StringTokenizer(path, "/");
StringBuilder stripped = new StringBuilder(path.length()); StringBuilder stripped = new StringBuilder(path.length());
if (path.charAt(0) == '/') { if (path.charAt(0) == '/') {
stripped.append('/'); stripped.append('/');
} }
while (tokenizer.hasMoreTokens()) {
while (st.hasMoreTokens()) { String segment = tokenizer.nextToken();
String segment = st.nextToken(); semicolonIndex = segment.indexOf(';');
scIndex = segment.indexOf(';'); if (semicolonIndex >= 0) {
segment = segment.substring(0, semicolonIndex);
if (scIndex >= 0) {
segment = segment.substring(0, scIndex);
} }
stripped.append(segment).append('/'); stripped.append(segment).append('/');
} }
// Remove the trailing slash if the original path didn't have one // Remove the trailing slash if the original path didn't have one
if (path.charAt(path.length() - 1) != '/') { if (path.charAt(path.length() - 1) != '/') {
stripped.deleteCharAt(stripped.length() - 1); stripped.deleteCharAt(stripped.length() - 1);
} }
return stripped.toString(); return stripped.toString();
} }

View File

@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;
/** /**
* <p> * <p>
@ -83,7 +84,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* Used to specify to {@link #setAllowedHttpMethods(Collection)} that any HTTP method * Used to specify to {@link #setAllowedHttpMethods(Collection)} that any HTTP method
* should be allowed. * should be allowed.
*/ */
private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.unmodifiableSet(Collections.emptySet()); private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.emptySet();
private static final String ENCODED_PERCENT = "%25"; private static final String ENCODED_PERCENT = "%25";
@ -165,15 +166,9 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see #setUnsafeAllowAnyHttpMethod(boolean) * @see #setUnsafeAllowAnyHttpMethod(boolean)
*/ */
public void setAllowedHttpMethods(Collection<String> allowedHttpMethods) { public void setAllowedHttpMethods(Collection<String> allowedHttpMethods) {
if (allowedHttpMethods == null) { Assert.notNull(allowedHttpMethods, "allowedHttpMethods cannot be null");
throw new IllegalArgumentException("allowedHttpMethods cannot be null"); this.allowedHttpMethods = (allowedHttpMethods != ALLOW_ANY_HTTP_METHOD) ? new HashSet<>(allowedHttpMethods)
} : ALLOW_ANY_HTTP_METHOD;
if (allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD;
}
else {
this.allowedHttpMethods = new HashSet<>(allowedHttpMethods);
}
} }
/** /**
@ -361,9 +356,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see Character#isDefined(int) * @see Character#isDefined(int)
*/ */
public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) { public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
if (allowedHeaderNames == null) { Assert.notNull(allowedHeaderNames, "allowedHeaderNames cannot be null");
throw new IllegalArgumentException("allowedHeaderNames cannot be null");
}
this.allowedHeaderNames = allowedHeaderNames; this.allowedHeaderNames = allowedHeaderNames;
} }
@ -378,28 +371,20 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see Character#isDefined(int) * @see Character#isDefined(int)
*/ */
public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) { public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
if (allowedHeaderValues == null) { Assert.notNull(allowedHeaderValues, "allowedHeaderValues cannot be null");
throw new IllegalArgumentException("allowedHeaderValues cannot be null");
}
this.allowedHeaderValues = allowedHeaderValues; this.allowedHeaderValues = allowedHeaderValues;
} }
/* /**
* Determines which parameter names should be allowed. The default is to reject header * Determines which parameter names should be allowed. The default is to reject header
* names that contain ISO control characters and characters that are not defined. </p> * names that contain ISO control characters and characters that are not defined.
*
* @param allowedParameterNames the predicate for testing parameter names * @param allowedParameterNames the predicate for testing parameter names
*
* @see Character#isISOControl(int)
*
* @see Character#isDefined(int)
*
* @since 5.4 * @since 5.4
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
*/ */
public void setAllowedParameterNames(Predicate<String> allowedParameterNames) { public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
if (allowedParameterNames == null) { Assert.notNull(allowedParameterNames, "allowedParameterNames cannot be null");
throw new IllegalArgumentException("allowedParameterNames cannot be null");
}
this.allowedParameterNames = allowedParameterNames; this.allowedParameterNames = allowedParameterNames;
} }
@ -412,9 +397,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @since 5.4 * @since 5.4
*/ */
public void setAllowedParameterValues(Predicate<String> allowedParameterValues) { public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
if (allowedParameterValues == null) { Assert.notNull(allowedParameterValues, "allowedParameterValues cannot be null");
throw new IllegalArgumentException("allowedParameterValues cannot be null");
}
this.allowedParameterValues = allowedParameterValues; this.allowedParameterValues = allowedParameterValues;
} }
@ -426,9 +409,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @since 5.2 * @since 5.2
*/ */
public void setAllowedHostnames(Predicate<String> allowedHostnames) { public void setAllowedHostnames(Predicate<String> allowedHostnames) {
if (allowedHostnames == null) { Assert.notNull(allowedHostnames, "allowedHostnames cannot be null");
throw new IllegalArgumentException("allowedHostnames cannot be null");
}
this.allowedHostnames = allowedHostnames; this.allowedHostnames = allowedHostnames;
} }
@ -447,173 +428,15 @@ public class StrictHttpFirewall implements HttpFirewall {
rejectForbiddenHttpMethod(request); rejectForbiddenHttpMethod(request);
rejectedBlocklistedUrls(request); rejectedBlocklistedUrls(request);
rejectedUntrustedHosts(request); rejectedUntrustedHosts(request);
if (!isNormalized(request)) { if (!isNormalized(request)) {
throw new RequestRejectedException("The request was rejected because the URL was not normalized."); throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
} }
String requestUri = request.getRequestURI(); String requestUri = request.getRequestURI();
if (!containsOnlyPrintableAsciiCharacters(requestUri)) { if (!containsOnlyPrintableAsciiCharacters(requestUri)) {
throw new RequestRejectedException( throw new RequestRejectedException(
"The requestURI was rejected because it can only contain printable ASCII characters."); "The requestURI was rejected because it can only contain printable ASCII characters.");
} }
return new FirewalledRequest(request) { return new StrictFirewalledRequest(request);
@Override
public long getDateHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getDateHeader(name);
}
@Override
public int getIntHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getIntHeader(name);
}
@Override
public String getHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
String value = super.getHeader(name);
if (value != null && !StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Enumeration<String> getHeaders(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
Enumeration<String> valuesEnumeration = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return valuesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String value = valuesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \""
+ value + "\" is not allowed.");
}
return value;
}
};
}
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> namesEnumeration = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \""
+ name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String getParameter(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String value = super.getParameter(name);
if (value != null && !StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
for (String value : values) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \""
+ value + "\" is not allowed.");
}
}
}
return parameterMap;
}
@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> namesEnumeration = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \""
+ name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String[] getParameterValues(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value : values) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \""
+ value + "\" is not allowed.");
}
}
}
return values;
}
@Override
public void reset() {
}
};
} }
private void rejectForbiddenHttpMethod(HttpServletRequest request) { private void rejectForbiddenHttpMethod(HttpServletRequest request) {
@ -705,12 +528,11 @@ public class StrictHttpFirewall implements HttpFirewall {
private static boolean containsOnlyPrintableAsciiCharacters(String uri) { private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
int length = uri.length(); int length = uri.length();
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
char c = uri.charAt(i); char ch = uri.charAt(i);
if (c < '\u0020' || c > '\u007e') { if (ch < '\u0020' || ch > '\u007e') {
return false; return false;
} }
} }
return true; return true;
} }
@ -728,22 +550,17 @@ public class StrictHttpFirewall implements HttpFirewall {
if (path == null) { if (path == null) {
return true; return true;
} }
for (int i = path.length(); i > 0;) {
for (int j = path.length(); j > 0;) { int slashIndex = path.lastIndexOf('/', i - 1);
int i = path.lastIndexOf('/', j - 1); int gap = i - slashIndex;
int gap = j - i; if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
return false; // ".", "/./" or "/."
if (gap == 2 && path.charAt(i + 1) == '.') { }
// ".", "/./" or "/." if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
return false; return false;
} }
else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') { i = slashIndex;
return false;
}
j = i;
} }
return true; return true;
} }
@ -782,4 +599,166 @@ public class StrictHttpFirewall implements HttpFirewall {
return getDecodedUrlBlocklist(); return getDecodedUrlBlocklist();
} }
/**
* Strict {@link FirewalledRequest}.
*/
private class StrictFirewalledRequest extends FirewalledRequest {
StrictFirewalledRequest(HttpServletRequest request) {
super(request);
}
@Override
public long getDateHeader(String name) {
validateAllowedHeaderName(name);
return super.getDateHeader(name);
}
@Override
public int getIntHeader(String name) {
validateAllowedHeaderName(name);
return super.getIntHeader(name);
}
@Override
public String getHeader(String name) {
validateAllowedHeaderName(name);
String value = super.getHeader(name);
if (value != null) {
validateAllowedHeaderValue(value);
}
return value;
}
@Override
public Enumeration<String> getHeaders(String name) {
validateAllowedHeaderName(name);
Enumeration<String> headers = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return headers.hasMoreElements();
}
@Override
public String nextElement() {
String value = headers.nextElement();
validateAllowedHeaderValue(value);
return value;
}
};
}
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> names = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return names.hasMoreElements();
}
@Override
public String nextElement() {
String headerNames = names.nextElement();
validateAllowedHeaderName(headerNames);
return headerNames;
}
};
}
@Override
public String getParameter(String name) {
validateAllowedParameterName(name);
String value = super.getParameter(name);
if (value != null) {
validateAllowedParameterValue(value);
}
return value;
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
validateAllowedParameterName(name);
for (String value : values) {
validateAllowedParameterValue(value);
}
}
return parameterMap;
}
@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> paramaterNames = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return paramaterNames.hasMoreElements();
}
@Override
public String nextElement() {
String name = paramaterNames.nextElement();
validateAllowedParameterName(name);
return name;
}
};
}
@Override
public String[] getParameterValues(String name) {
validateAllowedParameterName(name);
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value : values) {
validateAllowedParameterValue(value);
}
}
return values;
}
private void validateAllowedHeaderName(String headerNames) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(headerNames)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + headerNames + "\" is not allowed.");
}
}
private void validateAllowedHeaderValue(String value) {
if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the header value \"" + value + "\" is not allowed.");
}
}
private void validateAllowedParameterName(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
}
private void validateAllowedParameterValue(String value) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
@Override
public void reset() {
}
};
} }

View File

@ -62,20 +62,18 @@ public final class Header {
} }
@Override @Override
public boolean equals(Object o) { public boolean equals(Object obj) {
if (this == o) { if (this == obj) {
return true; return true;
} }
if (o == null || getClass() != o.getClass()) { if (obj == null || getClass() != obj.getClass()) {
return false; return false;
} }
Header other = (Header) obj;
Header header = (Header) o; if (!this.headerName.equals(other.headerName)) {
if (!this.headerName.equals(header.headerName)) {
return false; return false;
} }
return this.headerValues.equals(header.headerValues); return this.headerValues.equals(other.headerValues);
} }
@Override @Override

View File

@ -68,7 +68,6 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
if (this.shouldWriteHeadersEagerly) { if (this.shouldWriteHeadersEagerly) {
doHeadersBefore(request, response, filterChain); doHeadersBefore(request, response, filterChain);
} }

View File

@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -76,10 +77,9 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
response.setHeader(CLEAR_SITE_DATA_HEADER, this.headerValue); response.setHeader(CLEAR_SITE_DATA_HEADER, this.headerValue);
} }
} }
else if (this.logger.isDebugEnabled()) { this.logger.debug(
this.logger.debug("Not injecting Clear-Site-Data header since it did not match the " + "requestMatcher " LogMessage.format("Not injecting Clear-Site-Data header since it did not match the requestMatcher %s",
+ this.requestMatcher); this.requestMatcher));
}
} }
private String transformToHeaderValue(Directive... directives) { private String transformToHeaderValue(Directive... directives) {
@ -97,14 +97,19 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
} }
/** /**
* <p> * Represents the directive values expected by the {@link ClearSiteDataHeaderWriter}.
* Represents the directive values expected by the {@link ClearSiteDataHeaderWriter}
* </p>
* .
*/ */
public enum Directive { public enum Directive {
CACHE("cache"), COOKIES("cookies"), STORAGE("storage"), EXECUTION_CONTEXTS("executionContexts"), ALL("*"); CACHE("cache"),
COOKIES("cookies"),
STORAGE("storage"),
EXECUTION_CONTEXTS("executionContexts"),
ALL("*");
private final String headerValue; private final String headerValue;

View File

@ -117,7 +117,7 @@ public final class ContentSecurityPolicyHeaderWriter implements HeaderWriter {
*/ */
@Override @Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
String headerName = !this.reportOnly ? CONTENT_SECURITY_POLICY_HEADER String headerName = (!this.reportOnly) ? CONTENT_SECURITY_POLICY_HEADER
: CONTENT_SECURITY_POLICY_REPORT_ONLY_HEADER; : CONTENT_SECURITY_POLICY_REPORT_ONLY_HEADER;
if (!response.containsHeader(headerName)) { if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.policyDirectives); response.setHeader(headerName, this.policyDirectives);

View File

@ -174,19 +174,17 @@ public final class HpkpHeaderWriter implements HeaderWriter {
@Override @Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (this.requestMatcher.matches(request)) { if (!this.requestMatcher.matches(request)) {
if (!this.pins.isEmpty()) {
String headerName = this.reportOnly ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME;
if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.hpkpHeaderValue);
}
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Not injecting HPKP header since there aren't any pins");
}
}
else if (this.logger.isDebugEnabled()) {
this.logger.debug("Not injecting HPKP header since it wasn't a secure connection"); this.logger.debug("Not injecting HPKP header since it wasn't a secure connection");
return;
}
if (this.pins.isEmpty()) {
this.logger.debug("Not injecting HPKP header since there aren't any pins");
return;
}
String headerName = (this.reportOnly) ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME;
if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.hpkpHeaderValue);
} }
} }
@ -294,9 +292,7 @@ public final class HpkpHeaderWriter implements HeaderWriter {
* @throws IllegalArgumentException if maxAgeInSeconds is negative * @throws IllegalArgumentException if maxAgeInSeconds is negative
*/ */
public void setMaxAgeInSeconds(long maxAgeInSeconds) { public void setMaxAgeInSeconds(long maxAgeInSeconds) {
if (maxAgeInSeconds < 0) { Assert.isTrue(maxAgeInSeconds > 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
}
this.maxAgeInSeconds = maxAgeInSeconds; this.maxAgeInSeconds = maxAgeInSeconds;
updateHpkpHeaderValue(); updateHpkpHeaderValue();
} }
@ -414,11 +410,11 @@ public final class HpkpHeaderWriter implements HeaderWriter {
public void setReportUri(String reportUri) { public void setReportUri(String reportUri) {
try { try {
this.reportUri = new URI(reportUri); this.reportUri = new URI(reportUri);
updateHpkpHeaderValue();
} }
catch (URISyntaxException ex) { catch (URISyntaxException ex) {
throw new IllegalArgumentException(ex); throw new IllegalArgumentException(ex);
} }
updateHpkpHeaderValue();
} }
private void updateHpkpHeaderValue() { private void updateHpkpHeaderValue() {

View File

@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -148,14 +149,13 @@ public final class HstsHeaderWriter implements HeaderWriter {
@Override @Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (this.requestMatcher.matches(request)) { if (!this.requestMatcher.matches(request)) {
if (!response.containsHeader(HSTS_HEADER_NAME)) { this.logger.debug(LogMessage.format(
response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue); "Not injecting HSTS header since it did not match the requestMatcher %s", this.requestMatcher));
} return;
} }
else if (this.logger.isDebugEnabled()) { if (!response.containsHeader(HSTS_HEADER_NAME)) {
this.logger.debug( response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue);
"Not injecting HSTS header since it did not match the requestMatcher " + this.requestMatcher);
} }
} }
@ -188,9 +188,7 @@ public final class HstsHeaderWriter implements HeaderWriter {
* @throws IllegalArgumentException if maxAgeInSeconds is negative * @throws IllegalArgumentException if maxAgeInSeconds is negative
*/ */
public void setMaxAgeInSeconds(long maxAgeInSeconds) { public void setMaxAgeInSeconds(long maxAgeInSeconds) {
if (maxAgeInSeconds < 0) { Assert.isTrue(maxAgeInSeconds >= 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
}
this.maxAgeInSeconds = maxAgeInSeconds; this.maxAgeInSeconds = maxAgeInSeconds;
updateHstsHeaderValue(); updateHstsHeaderValue();
} }

View File

@ -100,10 +100,21 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter {
public enum ReferrerPolicy { public enum ReferrerPolicy {
NO_REFERRER("no-referrer"), NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"), SAME_ORIGIN( NO_REFERRER("no-referrer"),
"same-origin"), ORIGIN("origin"), STRICT_ORIGIN("strict-origin"), ORIGIN_WHEN_CROSS_ORIGIN(
"origin-when-cross-origin"), STRICT_ORIGIN_WHEN_CROSS_ORIGIN( NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"),
"strict-origin-when-cross-origin"), UNSAFE_URL("unsafe-url");
SAME_ORIGIN("same-origin"),
ORIGIN("origin"),
STRICT_ORIGIN("strict-origin"),
ORIGIN_WHEN_CROSS_ORIGIN("origin-when-cross-origin"),
STRICT_ORIGIN_WHEN_CROSS_ORIGIN("strict-origin-when-cross-origin"),
UNSAFE_URL("unsafe-url");
private static final Map<String, ReferrerPolicy> REFERRER_POLICIES; private static final Map<String, ReferrerPolicy> REFERRER_POLICIES;
@ -115,7 +126,7 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter {
REFERRER_POLICIES = Collections.unmodifiableMap(referrerPolicies); REFERRER_POLICIES = Collections.unmodifiableMap(referrerPolicies);
} }
private String policy; private final String policy;
ReferrerPolicy(String policy) { ReferrerPolicy(String policy) {
this.policy = policy; this.policy = policy;

View File

@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -52,15 +53,11 @@ public abstract class AbstractRequestParameterAllowFromStrategy implements Allow
@Override @Override
public String getAllowFromValue(HttpServletRequest request) { public String getAllowFromValue(HttpServletRequest request) {
String allowFromOrigin = request.getParameter(this.allowFromParameterName); String allowFromOrigin = request.getParameter(this.allowFromParameterName);
if (this.log.isDebugEnabled()) { this.log.debug(LogMessage.format("Supplied origin '%s'", allowFromOrigin));
this.log.debug("Supplied origin '" + allowFromOrigin + "'");
}
if (StringUtils.hasText(allowFromOrigin) && allowed(allowFromOrigin)) { if (StringUtils.hasText(allowFromOrigin) && allowed(allowFromOrigin)) {
return allowFromOrigin; return allowFromOrigin;
} }
else { return "DENY";
return "DENY";
}
} }
/** /**

View File

@ -55,10 +55,9 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
*/ */
public XFrameOptionsHeaderWriter(XFrameOptionsMode frameOptionsMode) { public XFrameOptionsHeaderWriter(XFrameOptionsMode frameOptionsMode) {
Assert.notNull(frameOptionsMode, "frameOptionsMode cannot be null"); Assert.notNull(frameOptionsMode, "frameOptionsMode cannot be null");
if (XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode)) { Assert.isTrue(!XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode),
throw new IllegalArgumentException( "ALLOW_FROM requires an AllowFromStrategy. Please use "
"ALLOW_FROM requires an AllowFromStrategy. Please use FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead"); + "FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead");
}
this.frameOptionsMode = frameOptionsMode; this.frameOptionsMode = frameOptionsMode;
this.allowFromStrategy = null; this.allowFromStrategy = null;
} }
@ -113,7 +112,10 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
*/ */
public enum XFrameOptionsMode { public enum XFrameOptionsMode {
DENY("DENY"), SAMEORIGIN("SAMEORIGIN"), DENY("DENY"),
SAMEORIGIN("SAMEORIGIN"),
/** /**
* @deprecated ALLOW-FROM is an obsolete directive that no longer works in modern * @deprecated ALLOW-FROM is an obsolete directive that no longer works in modern
* browsers. Instead use Content-Security-Policy with the <a href= * browsers. Instead use Content-Security-Policy with the <a href=
@ -123,7 +125,7 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
@Deprecated @Deprecated
ALLOW_FROM("ALLOW-FROM"); ALLOW_FROM("ALLOW-FROM");
private String mode; private final String mode;
XFrameOptionsMode(String mode) { XFrameOptionsMode(String mode) {
this.mode = mode; this.mode = mode;

View File

@ -29,6 +29,9 @@ import org.springframework.util.Assert;
*/ */
public final class SecurityHeaders { public final class SecurityHeaders {
private SecurityHeaders() {
}
/** /**
* Sets the provided value as a Bearer token in a header with the name of * Sets the provided value as a Bearer token in a header with the name of
* {@link HttpHeaders#AUTHORIZATION} * {@link HttpHeaders#AUTHORIZATION}
@ -40,7 +43,4 @@ public final class SecurityHeaders {
return (headers) -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + bearerTokenValue); return (headers) -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + bearerTokenValue);
} }
private SecurityHeaders() {
}
} }

Some files were not shown because too many files have changed in this diff Show More