diff --git a/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java index eb47070ccc..ef59b28453 100644 --- a/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java @@ -45,7 +45,6 @@ public interface AuthenticationEntryPoint { * @param request that resulted in an AuthenticationException * @param response so that the user agent can begin authentication * @param authException that caused the invocation - * */ void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException; diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java index 6175bac52b..b135699766 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java @@ -24,7 +24,9 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.Assert; /** * Simple implementation of RedirectStrategy which is the default used throughout @@ -51,11 +53,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy { public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException { String redirectUrl = calculateRedirectUrl(request.getContextPath(), url); redirectUrl = response.encodeRedirectURL(redirectUrl); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Redirecting to '" + redirectUrl + "'"); - } - + this.logger.debug(LogMessage.format("Redirecting to '%s'", redirectUrl)); response.sendRedirect(redirectUrl); } @@ -64,30 +62,20 @@ public class DefaultRedirectStrategy implements RedirectStrategy { if (isContextRelative()) { return url; } - else { - return contextPath + url; - } + return contextPath + url; } - // Full URL, including http(s):// - if (!isContextRelative()) { return url; } - - if (!url.contains(contextPath)) { - throw new IllegalArgumentException("The fully qualified URL does not include context path."); - } - + Assert.isTrue(url.contains(contextPath), "The fully qualified URL does not include context path."); // Calculate the relative URL from the fully qualified URL, minus the last // occurrence of the scheme and base context. - url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme + url = url.substring(url.lastIndexOf("://") + 3); url = url.substring(url.indexOf(contextPath) + contextPath.length()); - if (url.length() > 1 && url.charAt(0) == '/') { url = url.substring(1); } - return url; } diff --git a/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java b/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java index 8786a86a1c..6e52979b6e 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java +++ b/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java @@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.util.matcher.RequestMatcher; /** @@ -47,7 +48,7 @@ public final class DefaultSecurityFilterChain implements SecurityFilterChain { } public DefaultSecurityFilterChain(RequestMatcher requestMatcher, List filters) { - logger.info("Creating filter chain: " + requestMatcher + ", " + filters); + logger.info(LogMessage.format("Creating filter chain: %s, %s", requestMatcher, filters)); this.requestMatcher = requestMatcher; this.filters = new ArrayList<>(filters); } diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 47bb405a83..6a71c73e26 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -32,6 +32,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.firewall.DefaultRequestRejectedHandler; import org.springframework.security.web.firewall.FirewalledRequest; @@ -173,47 +174,37 @@ public class FilterChainProxy extends GenericFilterBean { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { boolean clearContext = request.getAttribute(FILTER_APPLIED) == null; - if (clearContext) { - try { - request.setAttribute(FILTER_APPLIED, Boolean.TRUE); - doFilterInternal(request, response, chain); - } - catch (RequestRejectedException ex) { - this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); - } - finally { - SecurityContextHolder.clearContext(); - request.removeAttribute(FILTER_APPLIED); - } - } - else { + if (!clearContext) { doFilterInternal(request, response, chain); + return; + } + try { + request.setAttribute(FILTER_APPLIED, Boolean.TRUE); + doFilterInternal(request, response, chain); + } + catch (RequestRejectedException ex) { + this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); + } + finally { + SecurityContextHolder.clearContext(); + request.removeAttribute(FILTER_APPLIED); } } private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - - FirewalledRequest fwRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request); - HttpServletResponse fwResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response); - - List filters = getFilters(fwRequest); - + FirewalledRequest firewallRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request); + HttpServletResponse firewallResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response); + List filters = getFilters(firewallRequest); if (filters == null || filters.size() == 0) { - if (logger.isDebugEnabled()) { - logger.debug(UrlUtils.buildRequestUrl(fwRequest) - + ((filters != null) ? " has an empty filter list" : " has no matching filters")); - } - - fwRequest.reset(); - - chain.doFilter(fwRequest, fwResponse); - + logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(firewallRequest) + + ((filters != null) ? " has an empty filter list" : " has no matching filters"))); + firewallRequest.reset(); + chain.doFilter(firewallRequest, firewallResponse); return; } - - VirtualFilterChain vfc = new VirtualFilterChain(fwRequest, chain, filters); - vfc.doFilter(fwRequest, fwResponse); + VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters); + virtualFilterChain.doFilter(firewallRequest, firewallResponse); } /** @@ -227,7 +218,6 @@ public class FilterChainProxy extends GenericFilterBean { return chain.getFilters(); } } - return null; } @@ -286,7 +276,6 @@ public class FilterChainProxy extends GenericFilterBean { sb.append("Filter Chains: "); sb.append(this.filterChains); sb.append("]"); - return sb.toString(); } @@ -317,30 +306,19 @@ public class FilterChainProxy extends GenericFilterBean { @Override public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { if (this.currentPosition == this.size) { - if (logger.isDebugEnabled()) { - logger.debug(UrlUtils.buildRequestUrl(this.firewalledRequest) - + " reached end of additional filter chain; proceeding with original chain"); - } - + logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest) + + " reached end of additional filter chain; proceeding with original chain")); // Deactivate path stripping as we exit the security filter chain this.firewalledRequest.reset(); - this.originalChain.doFilter(request, response); + return; } - else { - this.currentPosition++; - - Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1); - - if (logger.isDebugEnabled()) { - logger.debug( - UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position " + this.currentPosition - + " of " + this.size + " in additional filter chain; firing Filter: '" - + nextFilter.getClass().getSimpleName() + "'"); - } - - nextFilter.doFilter(request, response, this); - } + this.currentPosition++; + Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1); + logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position " + + this.currentPosition + " of " + this.size + " in additional filter chain; firing Filter: '" + + nextFilter.getClass().getSimpleName() + "'")); + nextFilter.doFilter(request, response, this); } } diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java index 879b1ce94b..1062c4eedc 100644 --- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java +++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java @@ -37,6 +37,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.Assert; /** * Holds objects associated with a HTTP filter. @@ -65,10 +66,7 @@ public class FilterInvocation { private HttpServletResponse response; public FilterInvocation(ServletRequest request, ServletResponse response, FilterChain chain) { - if ((request == null) || (response == null) || (chain == null)) { - throw new IllegalArgumentException("Cannot pass null values to constructor"); - } - + Assert.isTrue(request != null && response != null && chain != null, "Cannot pass null values to constructor"); this.request = (HttpServletRequest) request; this.response = (HttpServletResponse) response; this.chain = chain; @@ -84,9 +82,7 @@ public class FilterInvocation { public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) { DummyRequest request = new DummyRequest(); - if (contextPath == null) { - contextPath = "/cp"; - } + contextPath = (contextPath != null) ? contextPath : "/cp"; request.setContextPath(contextPath); request.setServletPath(servletPath); request.setRequestURI(contextPath + servletPath + ((pathInfo != null) ? pathInfo : "")); @@ -256,9 +252,7 @@ public class FilterInvocation { if (value == null) { return -1; } - else { - return Integer.parseInt(value); - } + return Integer.parseInt(value); } void addHeader(String name, String value) { @@ -267,8 +261,8 @@ public class FilterInvocation { @Override public String getParameter(String name) { - String[] arr = this.parameters.get(name); - return (arr != null && arr.length > 0) ? arr[0] : null; + String[] array = this.parameters.get(name); + return (array != null && array.length > 0) ? array[0] : null; } @Override @@ -317,7 +311,6 @@ public class FilterInvocation { private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable { Constructor constructor = Lookup.class.getDeclaredConstructor(Class.class); constructor.setAccessible(true); - Class clazz = method.getDeclaringClass(); return constructor.newInstance(clazz).in(clazz).unreflectSpecial(method, clazz).bindTo(proxy) .invokeWithArguments(args); diff --git a/web/src/main/java/org/springframework/security/web/PortMapperImpl.java b/web/src/main/java/org/springframework/security/web/PortMapperImpl.java index 9b5f416ea5..947cc93c62 100644 --- a/web/src/main/java/org/springframework/security/web/PortMapperImpl.java +++ b/web/src/main/java/org/springframework/security/web/PortMapperImpl.java @@ -56,7 +56,6 @@ public class PortMapperImpl implements PortMapper { return httpPort; } } - return null; } @@ -88,24 +87,19 @@ public class PortMapperImpl implements PortMapper { */ public void setPortMappings(Map newMappings) { Assert.notNull(newMappings, "A valid list of HTTPS port mappings must be provided"); - this.httpsPortMappings.clear(); - for (Map.Entry entry : newMappings.entrySet()) { Integer httpPort = Integer.valueOf(entry.getKey()); Integer httpsPort = Integer.valueOf(entry.getValue()); - - if ((httpPort < 1) || (httpPort > 65535) || (httpsPort < 1) || (httpsPort > 65535)) { - throw new IllegalArgumentException( - "one or both ports out of legal range: " + httpPort + ", " + httpsPort); - } - + Assert.isTrue(isInPortRange(httpPort) && isInPortRange(httpsPort), + () -> "one or both ports out of legal range: " + httpPort + ", " + httpsPort); this.httpsPortMappings.put(httpPort, httpsPort); } + Assert.isTrue(!this.httpsPortMappings.isEmpty(), "must map at least one port"); + } - if (this.httpsPortMappings.size() < 1) { - throw new IllegalArgumentException("must map at least one port"); - } + private boolean isInPortRange(int port) { + return port >= 1 && port <= 65535; } } diff --git a/web/src/main/java/org/springframework/security/web/PortResolverImpl.java b/web/src/main/java/org/springframework/security/web/PortResolverImpl.java index e17edef87b..faa01d83c3 100644 --- a/web/src/main/java/org/springframework/security/web/PortResolverImpl.java +++ b/web/src/main/java/org/springframework/security/web/PortResolverImpl.java @@ -45,24 +45,19 @@ public class PortResolverImpl implements PortResolver { @Override public int getServerPort(ServletRequest request) { int serverPort = request.getServerPort(); - Integer portLookup = null; - String scheme = request.getScheme().toLowerCase(); + Integer mappedPort = getMappedPort(serverPort, scheme); + return (mappedPort != null) ? mappedPort : serverPort; + } + private Integer getMappedPort(int serverPort, String scheme) { if ("http".equals(scheme)) { - portLookup = this.portMapper.lookupHttpPort(serverPort); - + return this.portMapper.lookupHttpPort(serverPort); } - else if ("https".equals(scheme)) { - portLookup = this.portMapper.lookupHttpsPort(serverPort); + if ("https".equals(scheme)) { + return this.portMapper.lookupHttpsPort(serverPort); } - - if (portLookup != null) { - // IE 6 bug - serverPort = portLookup; - } - - return serverPort; + return null; } public void setPortMapper(PortMapper portMapper) { diff --git a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java index f129b4d8bc..26315b9ec0 100644 --- a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java +++ b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java @@ -18,7 +18,6 @@ package org.springframework.security.web.access; import java.io.IOException; -import javax.servlet.RequestDispatcher; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -29,6 +28,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpStatus; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.web.WebAttributes; +import org.springframework.util.Assert; /** * Base implementation of {@link AccessDeniedHandler}. @@ -52,22 +52,19 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler { @Override public void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException accessDeniedException) throws IOException, ServletException { - if (!response.isCommitted()) { - if (this.errorPage != null) { - // Put exception into request scope (perhaps of use to a view) - request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException); - - // Set the 403 status code. - response.setStatus(HttpStatus.FORBIDDEN.value()); - - // forward to error page. - RequestDispatcher dispatcher = request.getRequestDispatcher(this.errorPage); - dispatcher.forward(request, response); - } - else { - response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase()); - } + if (response.isCommitted()) { + return; } + if (this.errorPage == null) { + response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase()); + return; + } + // Put exception into request scope (perhaps of use to a view) + request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException); + // Set the 403 status code. + response.setStatus(HttpStatus.FORBIDDEN.value()); + // forward to error page. + request.getRequestDispatcher(this.errorPage).forward(request, response); } /** @@ -78,10 +75,7 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler { * limitations */ public void setErrorPage(String errorPage) { - if ((errorPage != null) && !errorPage.startsWith("/")) { - throw new IllegalArgumentException("errorPage must begin with '/'"); - } - + Assert.isTrue(errorPage == null || errorPage.startsWith("/"), "errorPage must begin with '/'"); this.errorPage = errorPage; } diff --git a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java index 92ceac2ba7..7030d29c46 100644 --- a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java +++ b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java @@ -21,6 +21,7 @@ import java.util.Collection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.intercept.AbstractSecurityInterceptor; @@ -47,7 +48,6 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv "AbstractSecurityInterceptor does not support FilterInvocations"); Assert.notNull(securityInterceptor.getAccessDecisionManager(), "AbstractSecurityInterceptor must provide a non-null AccessDecisionManager"); - this.securityInterceptor = securityInterceptor; } @@ -82,34 +82,23 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv @Override public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) { Assert.notNull(uri, "uri parameter is required"); - - FilterInvocation fi = new FilterInvocation(contextPath, uri, method); - Collection attrs = this.securityInterceptor.obtainSecurityMetadataSource().getAttributes(fi); - - if (attrs == null) { - if (this.securityInterceptor.isRejectPublicInvocations()) { - return false; - } - - return true; + FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method); + Collection attributes = this.securityInterceptor.obtainSecurityMetadataSource() + .getAttributes(filterInvocation); + if (attributes == null) { + return (!this.securityInterceptor.isRejectPublicInvocations()); } - if (authentication == null) { return false; } - try { - this.securityInterceptor.getAccessDecisionManager().decide(authentication, fi, attrs); + this.securityInterceptor.getAccessDecisionManager().decide(authentication, filterInvocation, attributes); + return true; } - catch (AccessDeniedException unauthorized) { - if (logger.isDebugEnabled()) { - logger.debug(fi.toString() + " denied for " + authentication.toString(), unauthorized); - } - + catch (AccessDeniedException ex) { + logger.debug(LogMessage.format("%s denied for %s", filterInvocation, authentication), ex); return false; } - - return true; } } diff --git a/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java b/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java index 2ff4908806..dd0360aa87 100644 --- a/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java +++ b/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java @@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; @@ -107,14 +108,15 @@ public class ExceptionTranslationFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { try { chain.doFilter(request, response); - this.logger.debug("Chain processed normally"); } catch (IOException ex) { @@ -123,38 +125,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean { catch (Exception ex) { // Try to extract a SpringSecurityException from the stacktrace Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); - RuntimeException ase = (AuthenticationException) this.throwableAnalyzer + RuntimeException securityException = (AuthenticationException) this.throwableAnalyzer .getFirstThrowableOfType(AuthenticationException.class, causeChain); - - if (ase == null) { - ase = (AccessDeniedException) this.throwableAnalyzer + if (securityException == null) { + securityException = (AccessDeniedException) this.throwableAnalyzer .getFirstThrowableOfType(AccessDeniedException.class, causeChain); } - - if (ase != null) { - if (response.isCommitted()) { - throw new ServletException( - "Unable to handle the Spring Security Exception because the response is already committed.", - ex); - } - handleSpringSecurityException(request, response, chain, ase); + if (securityException == null) { + rethrow(ex); } - else { - // Rethrow ServletExceptions and RuntimeExceptions as-is - if (ex instanceof ServletException) { - throw (ServletException) ex; - } - else if (ex instanceof RuntimeException) { - throw (RuntimeException) ex; - } - - // Wrap other Exceptions. This shouldn't actually happen - // as we've already covered all the possibilities for doFilter - throw new RuntimeException(ex); + if (response.isCommitted()) { + throw new ServletException("Unable to handle the Spring Security Exception " + + "because the response is already committed.", ex); } + handleSpringSecurityException(request, response, chain, securityException); } } + private void rethrow(Exception ex) throws ServletException { + // Rethrow ServletExceptions and RuntimeExceptions as-is + if (ex instanceof ServletException) { + throw (ServletException) ex; + } + if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } + // Wrap other Exceptions. This shouldn't actually happen + // as we've already covered all the possibilities for doFilter + throw new RuntimeException(ex); + } + public AuthenticationEntryPoint getAuthenticationEntryPoint() { return this.authenticationEntryPoint; } @@ -166,32 +166,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean { private void handleSpringSecurityException(HttpServletRequest request, HttpServletResponse response, FilterChain chain, RuntimeException exception) throws IOException, ServletException { if (exception instanceof AuthenticationException) { - this.logger.debug("Authentication exception occurred; redirecting to authentication entry point", - exception); - - sendStartAuthentication(request, response, chain, (AuthenticationException) exception); + handleAuthenticationException(request, response, chain, (AuthenticationException) exception); } else if (exception instanceof AccessDeniedException) { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (this.authenticationTrustResolver.isAnonymous(authentication) - || this.authenticationTrustResolver.isRememberMe(authentication)) { - this.logger.debug( - "Access is denied (user is " + (this.authenticationTrustResolver.isAnonymous(authentication) - ? "anonymous" : "not fully authenticated") - + "); redirecting to authentication entry point", - exception); + handleAccessDeniedException(request, response, chain, (AccessDeniedException) exception); + } + } - sendStartAuthentication(request, response, chain, - new InsufficientAuthenticationException( - this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication", - "Full authentication is required to access this resource"))); - } - else { - this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler", - exception); + private void handleAuthenticationException(HttpServletRequest request, HttpServletResponse response, + FilterChain chain, AuthenticationException exception) throws ServletException, IOException { + this.logger.debug("Authentication exception occurred; redirecting to authentication entry point", exception); + sendStartAuthentication(request, response, chain, exception); + } - this.accessDeniedHandler.handle(request, response, (AccessDeniedException) exception); - } + private void handleAccessDeniedException(HttpServletRequest request, HttpServletResponse response, + FilterChain chain, AccessDeniedException exception) throws ServletException, IOException { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + boolean isAnonymous = this.authenticationTrustResolver.isAnonymous(authentication); + if (isAnonymous || this.authenticationTrustResolver.isRememberMe(authentication)) { + this.logger.debug(LogMessage + .of(() -> "Access is denied (user is " + (isAnonymous ? "anonymous" : "not fully authenticated") + + "); redirecting to authentication entry point"), + exception); + sendStartAuthentication(request, response, chain, + new InsufficientAuthenticationException( + this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication", + "Full authentication is required to access this resource"))); + } + else { + this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler", exception); + this.accessDeniedHandler.handle(request, response, exception); } } @@ -232,7 +236,6 @@ public class ExceptionTranslationFilter extends GenericFilterBean { @Override protected void initExtractorMap() { super.initExtractorMap(); - registerExtractor(ServletException.class, (throwable) -> { ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class); return ((ServletException) throwable).getRootCause(); diff --git a/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java b/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java index 607ada7667..81fd2279b4 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.PortMapper; import org.springframework.security.web.PortMapperImpl; @@ -43,10 +44,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { private PortResolver portResolver = new PortResolverImpl(); - /** The scheme ("http://" or "https://") */ + /** + * The scheme ("http://" or "https://") + */ private final String scheme; - /** The standard port for the scheme (80 for http, 443 for https) */ + /** + * The standard port for the scheme (80 for http, 443 for https) + */ private final int standardPort; private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); @@ -60,21 +65,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { public void commence(HttpServletRequest request, HttpServletResponse response) throws IOException { String queryString = request.getQueryString(); String redirectUrl = request.getRequestURI() + ((queryString != null) ? ("?" + queryString) : ""); - Integer currentPort = this.portResolver.getServerPort(request); Integer redirectPort = getMappedPort(currentPort); - if (redirectPort != null) { boolean includePort = redirectPort != this.standardPort; - - redirectUrl = this.scheme + request.getServerName() + ((includePort) ? (":" + redirectPort) : "") - + redirectUrl; + String port = (includePort) ? (":" + redirectPort) : ""; + redirectUrl = this.scheme + request.getServerName() + port + redirectUrl; } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Redirecting to: " + redirectUrl); - } - + this.logger.debug(LogMessage.format("Redirecting to: %s", redirectUrl)); this.redirectStrategy.sendRedirect(request, response, redirectUrl); } diff --git a/web/src/main/java/org/springframework/security/web/access/channel/ChannelDecisionManagerImpl.java b/web/src/main/java/org/springframework/security/web/access/channel/ChannelDecisionManagerImpl.java index 2266f457ac..5f6e05b29b 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/ChannelDecisionManagerImpl.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/ChannelDecisionManagerImpl.java @@ -64,10 +64,8 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi return; } } - for (ChannelProcessor processor : this.channelProcessors) { processor.decide(invocation, config); - if (invocation.getResponse().isCommitted()) { break; } @@ -79,11 +77,10 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi } @SuppressWarnings("cast") - public void setChannelProcessors(List newList) { - Assert.notEmpty(newList, "A list of ChannelProcessors is required"); - this.channelProcessors = new ArrayList<>(newList.size()); - - for (Object currentObject : newList) { + public void setChannelProcessors(List channelProcessors) { + Assert.notEmpty(channelProcessors, "A list of ChannelProcessors is required"); + this.channelProcessors = new ArrayList<>(channelProcessors.size()); + for (Object currentObject : channelProcessors) { Assert.isInstanceOf(ChannelProcessor.class, currentObject, () -> "ChannelProcessor " + currentObject.getClass().getName() + " must implement ChannelProcessor"); this.channelProcessors.add((ChannelProcessor) currentObject); @@ -95,13 +92,11 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi if (ANY_CHANNEL.equals(attribute.getAttribute())) { return true; } - for (ChannelProcessor processor : this.channelProcessors) { if (processor.supports(attribute)) { return true; } } - return false; } diff --git a/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessingFilter.java b/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessingFilter.java index 2d7d8c6992..b9dc1acb3e 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessingFilter.java @@ -28,6 +28,7 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource; @@ -93,35 +94,26 @@ public class ChannelProcessingFilter extends GenericFilterBean { public void afterPropertiesSet() { Assert.notNull(this.securityMetadataSource, "securityMetadataSource must be specified"); Assert.notNull(this.channelDecisionManager, "channelDecisionManager must be specified"); - - Collection attrDefs = this.securityMetadataSource.getAllConfigAttributes(); - - if (attrDefs == null) { - if (this.logger.isWarnEnabled()) { - this.logger.warn( - "Could not validate configuration attributes as the FilterInvocationSecurityMetadataSource did " - + "not return any attributes"); - } - + Collection attributes = this.securityMetadataSource.getAllConfigAttributes(); + if (attributes == null) { + this.logger.warn("Could not validate configuration attributes as the " + + "FilterInvocationSecurityMetadataSource did not return any attributes"); return; } + Set unsupportedAttributes = getUnsupportedAttributes(attributes); + Assert.isTrue(unsupportedAttributes.isEmpty(), + () -> "Unsupported configuration attributes: " + unsupportedAttributes); + this.logger.info("Validated configuration attributes"); + } + private Set getUnsupportedAttributes(Collection attrDefs) { Set unsupportedAttributes = new HashSet<>(); - for (ConfigAttribute attr : attrDefs) { if (!this.channelDecisionManager.supports(attr)) { unsupportedAttributes.add(attr); } } - - if (unsupportedAttributes.size() == 0) { - if (this.logger.isInfoEnabled()) { - this.logger.info("Validated configuration attributes"); - } - } - else { - throw new IllegalArgumentException("Unsupported configuration attributes: " + unsupportedAttributes); - } + return unsupportedAttributes; } @Override @@ -129,22 +121,15 @@ public class ChannelProcessingFilter extends GenericFilterBean { throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) req; HttpServletResponse response = (HttpServletResponse) res; - - FilterInvocation fi = new FilterInvocation(request, response, chain); - Collection attr = this.securityMetadataSource.getAttributes(fi); - - if (attr != null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Request: " + fi.toString() + "; ConfigAttributes: " + attr); - } - - this.channelDecisionManager.decide(fi, attr); - - if (fi.getResponse().isCommitted()) { + FilterInvocation filterInvocation = new FilterInvocation(request, response, chain); + Collection attributes = this.securityMetadataSource.getAttributes(filterInvocation); + if (attributes != null) { + this.logger.debug(LogMessage.format("Request: %s; ConfigAttributes: %s", filterInvocation, attributes)); + this.channelDecisionManager.decide(filterInvocation, attributes); + if (filterInvocation.getResponse().isCommitted()) { return; } } - chain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessor.java b/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessor.java index 1aca25bf4b..19cb406f35 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessor.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/ChannelProcessor.java @@ -40,7 +40,6 @@ public interface ChannelProcessor { /** * Decided whether the presented {@link FilterInvocation} provides the appropriate * level of channel security based on the requested list of ConfigAttributes. - * */ void decide(FilterInvocation invocation, Collection config) throws IOException, ServletException; diff --git a/web/src/main/java/org/springframework/security/web/access/channel/InsecureChannelProcessor.java b/web/src/main/java/org/springframework/security/web/access/channel/InsecureChannelProcessor.java index 6f4d7b43bc..3c214f73d7 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/InsecureChannelProcessor.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/InsecureChannelProcessor.java @@ -55,10 +55,7 @@ public class InsecureChannelProcessor implements InitializingBean, ChannelProces @Override public void decide(FilterInvocation invocation, Collection config) throws IOException, ServletException { - if ((invocation == null) || (config == null)) { - throw new IllegalArgumentException("Nulls cannot be provided"); - } - + Assert.isTrue(invocation != null && config != null, "Nulls cannot be provided"); for (ConfigAttribute attribute : config) { if (supports(attribute)) { if (invocation.getHttpRequest().isSecure()) { diff --git a/web/src/main/java/org/springframework/security/web/access/channel/SecureChannelProcessor.java b/web/src/main/java/org/springframework/security/web/access/channel/SecureChannelProcessor.java index 4e0f501208..507bd67906 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/SecureChannelProcessor.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/SecureChannelProcessor.java @@ -56,7 +56,6 @@ public class SecureChannelProcessor implements InitializingBean, ChannelProcesso public void decide(FilterInvocation invocation, Collection config) throws IOException, ServletException { Assert.isTrue((invocation != null) && (config != null), "Nulls cannot be provided"); - for (ConfigAttribute attribute : config) { if (supports(attribute)) { if (!invocation.getHttpRequest().isSecure()) { diff --git a/web/src/main/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessor.java b/web/src/main/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessor.java index e562ea1e4b..b15a7c02ff 100644 --- a/web/src/main/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessor.java +++ b/web/src/main/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessor.java @@ -41,25 +41,37 @@ abstract class AbstractVariableEvaluationContextPostProcessor @Override public final EvaluationContext postProcess(EvaluationContext context, FilterInvocation invocation) { - final HttpServletRequest request = invocation.getHttpRequest(); - return new DelegatingEvaluationContext(context) { - private Map variables; - - @Override - public Object lookupVariable(String name) { - Object result = super.lookupVariable(name); - if (result != null) { - return result; - } - if (this.variables == null) { - this.variables = extractVariables(request); - } - return this.variables.get(name); - } - - }; + return new VariableEvaluationContext(context, invocation.getHttpRequest()); } abstract Map extractVariables(HttpServletRequest request); + /** + * {@link DelegatingEvaluationContext} to expose variable. + */ + class VariableEvaluationContext extends DelegatingEvaluationContext { + + private final HttpServletRequest request; + + private Map variables; + + VariableEvaluationContext(EvaluationContext delegate, HttpServletRequest request) { + super(delegate); + this.request = request; + } + + @Override + public Object lookupVariable(String name) { + Object result = super.lookupVariable(name); + if (result != null) { + return result; + } + if (this.variables == null) { + this.variables = extractVariables(this.request); + } + return this.variables.get(name); + } + + } + } diff --git a/web/src/main/java/org/springframework/security/web/access/expression/ExpressionBasedFilterInvocationSecurityMetadataSource.java b/web/src/main/java/org/springframework/security/web/access/expression/ExpressionBasedFilterInvocationSecurityMetadataSource.java index b8e4838794..c1b50f2140 100644 --- a/web/src/main/java/org/springframework/security/web/access/expression/ExpressionBasedFilterInvocationSecurityMetadataSource.java +++ b/web/src/main/java/org/springframework/security/web/access/expression/ExpressionBasedFilterInvocationSecurityMetadataSource.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; +import java.util.function.BiConsumer; import javax.servlet.http.HttpServletRequest; @@ -58,29 +59,29 @@ public final class ExpressionBasedFilterInvocationSecurityMetadataSource private static LinkedHashMap> processMap( LinkedHashMap> requestMap, ExpressionParser parser) { Assert.notNull(parser, "SecurityExpressionHandler returned a null parser object"); + LinkedHashMap> processed = new LinkedHashMap<>(requestMap); + requestMap.forEach((request, value) -> process(parser, request, value, processed::put)); + return processed; + } - LinkedHashMap> requestToExpressionAttributesMap = new LinkedHashMap<>( - requestMap); - - for (Map.Entry> entry : requestMap.entrySet()) { - RequestMatcher request = entry.getKey(); - Assert.isTrue(entry.getValue().size() == 1, () -> "Expected a single expression attribute for " + request); - ArrayList attributes = new ArrayList<>(1); - String expression = entry.getValue().toArray(new ConfigAttribute[1])[0].getAttribute(); - logger.debug("Adding web access control expression '" + expression + "', for " + request); - - AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request); - try { - attributes.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor)); - } - catch (ParseException ex) { - throw new IllegalArgumentException("Failed to parse expression '" + expression + "'"); - } - - requestToExpressionAttributesMap.put(request, attributes); + private static void process(ExpressionParser parser, RequestMatcher request, Collection value, + BiConsumer> consumer) { + String expression = getExpression(request, value); + logger.debug("Adding web access control expression '" + expression + "', for " + request); + AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request); + ArrayList processed = new ArrayList<>(1); + try { + processed.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor)); } + catch (ParseException ex) { + throw new IllegalArgumentException("Failed to parse expression '" + expression + "'"); + } + consumer.accept(request, processed); + } - return requestToExpressionAttributesMap; + private static String getExpression(RequestMatcher request, Collection value) { + Assert.isTrue(value.size() == 1, () -> "Expected a single expression attribute for " + request); + return value.toArray(new ConfigAttribute[1])[0].getAttribute(); } private static AbstractVariableEvaluationContextPostProcessor createPostProcessor(RequestMatcher request) { diff --git a/web/src/main/java/org/springframework/security/web/access/expression/WebExpressionVoter.java b/web/src/main/java/org/springframework/security/web/access/expression/WebExpressionVoter.java index 88db17f211..0822cb1532 100644 --- a/web/src/main/java/org/springframework/security/web/access/expression/WebExpressionVoter.java +++ b/web/src/main/java/org/springframework/security/web/access/expression/WebExpressionVoter.java @@ -25,6 +25,7 @@ import org.springframework.security.access.expression.ExpressionUtils; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.core.Authentication; import org.springframework.security.web.FilterInvocation; +import org.springframework.util.Assert; /** * Voter which handles web authorisation decisions. @@ -37,21 +38,19 @@ public class WebExpressionVoter implements AccessDecisionVoter private SecurityExpressionHandler expressionHandler = new DefaultWebSecurityExpressionHandler(); @Override - public int vote(Authentication authentication, FilterInvocation fi, Collection attributes) { - assert authentication != null; - assert fi != null; - assert attributes != null; - - WebExpressionConfigAttribute weca = findConfigAttribute(attributes); - - if (weca == null) { + public int vote(Authentication authentication, FilterInvocation filterInvocation, + Collection attributes) { + Assert.notNull(authentication, "authentication must not be null"); + Assert.notNull(filterInvocation, "filterInvocation must not be null"); + Assert.notNull(attributes, "attributes must not be null"); + WebExpressionConfigAttribute webExpressionConfigAttribute = findConfigAttribute(attributes); + if (webExpressionConfigAttribute == null) { return ACCESS_ABSTAIN; } - - EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, fi); - ctx = weca.postProcess(ctx, fi); - - return ExpressionUtils.evaluateAsBoolean(weca.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED; + EvaluationContext ctx = webExpressionConfigAttribute.postProcess( + this.expressionHandler.createEvaluationContext(authentication, filterInvocation), filterInvocation); + return ExpressionUtils.evaluateAsBoolean(webExpressionConfigAttribute.getAuthorizeExpression(), ctx) + ? ACCESS_GRANTED : ACCESS_DENIED; } private WebExpressionConfigAttribute findConfigAttribute(Collection attributes) { diff --git a/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java b/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java index 70fd24f022..91bc2df290 100644 --- a/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java +++ b/web/src/main/java/org/springframework/security/web/access/expression/WebSecurityExpressionRoot.java @@ -29,13 +29,13 @@ import org.springframework.security.web.util.matcher.IpAddressMatcher; */ public class WebSecurityExpressionRoot extends SecurityExpressionRoot { - // private FilterInvocation filterInvocation; - /** Allows direct access to the request object */ + /** + * Allows direct access to the request object + */ public final HttpServletRequest request; public WebSecurityExpressionRoot(Authentication a, FilterInvocation fi) { super(a); - // this.filterInvocation = fi; this.request = fi.getRequest(); } @@ -47,7 +47,8 @@ public class WebSecurityExpressionRoot extends SecurityExpressionRoot { * @return true if the IP address of the current request is in the required range. */ public boolean hasIpAddress(String ipAddress) { - return (new IpAddressMatcher(ipAddress).matches(this.request)); + IpAddressMatcher matcher = new IpAddressMatcher(ipAddress); + return matcher.matches(this.request); } } diff --git a/web/src/main/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSource.java b/web/src/main/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSource.java index 91507ec62d..5f2661a32f 100644 --- a/web/src/main/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSource.java +++ b/web/src/main/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSource.java @@ -65,18 +65,13 @@ public class DefaultFilterInvocationSecurityMetadataSource implements FilterInvo */ public DefaultFilterInvocationSecurityMetadataSource( LinkedHashMap> requestMap) { - this.requestMap = requestMap; } @Override public Collection getAllConfigAttributes() { Set allAttributes = new HashSet<>(); - - for (Map.Entry> entry : this.requestMap.entrySet()) { - allAttributes.addAll(entry.getValue()); - } - + this.requestMap.values().forEach(allAttributes::addAll); return allAttributes; } diff --git a/web/src/main/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptor.java b/web/src/main/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptor.java index 43d84624cc..1472a1a05e 100644 --- a/web/src/main/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptor.java +++ b/web/src/main/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptor.java @@ -78,8 +78,7 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - FilterInvocation fi = new FilterInvocation(request, response, chain); - invoke(fi); + invoke(new FilterInvocation(request, response, chain)); } public FilterInvocationSecurityMetadataSource getSecurityMetadataSource() { @@ -100,30 +99,30 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple return FilterInvocation.class; } - public void invoke(FilterInvocation fi) throws IOException, ServletException { - if ((fi.getRequest() != null) && (fi.getRequest().getAttribute(FILTER_APPLIED) != null) - && this.observeOncePerRequest) { + public void invoke(FilterInvocation filterInvocation) throws IOException, ServletException { + if (isApplied(filterInvocation) && this.observeOncePerRequest) { // filter already applied to this request and user wants us to observe // once-per-request handling, so don't re-do security checking - fi.getChain().doFilter(fi.getRequest(), fi.getResponse()); + filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse()); + return; } - else { - // first time this request being called, so perform security checking - if (fi.getRequest() != null && this.observeOncePerRequest) { - fi.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE); - } - - InterceptorStatusToken token = super.beforeInvocation(fi); - - try { - fi.getChain().doFilter(fi.getRequest(), fi.getResponse()); - } - finally { - super.finallyInvocation(token); - } - - super.afterInvocation(token, null); + // first time this request being called, so perform security checking + if (filterInvocation.getRequest() != null && this.observeOncePerRequest) { + filterInvocation.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE); } + InterceptorStatusToken token = super.beforeInvocation(filterInvocation); + try { + filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse()); + } + finally { + super.finallyInvocation(token); + } + super.afterInvocation(token, null); + } + + private boolean isApplied(FilterInvocation filterInvocation) { + return (filterInvocation.getRequest() != null) + && (filterInvocation.getRequest().getAttribute(FILTER_APPLIED) != null); } /** diff --git a/web/src/main/java/org/springframework/security/web/access/intercept/RequestKey.java b/web/src/main/java/org/springframework/security/web/access/intercept/RequestKey.java index 8e6c569313..b608e4f8b3 100644 --- a/web/src/main/java/org/springframework/security/web/access/intercept/RequestKey.java +++ b/web/src/main/java/org/springframework/security/web/access/intercept/RequestKey.java @@ -77,7 +77,6 @@ public class RequestKey { } sb.append(this.url); sb.append("]"); - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java index a0ae2461d1..f4f135979f 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java @@ -30,6 +30,7 @@ import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.InternalAuthenticationServiceException; @@ -206,52 +207,39 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt * */ @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; - + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { if (!requiresAuthentication(request, response)) { chain.doFilter(request, response); - return; } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Request is to process authentication"); - } - - Authentication authResult; - + this.logger.debug("Request is to process authentication"); try { - authResult = attemptAuthentication(request, response); - if (authResult == null) { + Authentication authenticationResult = attemptAuthentication(request, response); + if (authenticationResult == null) { // return immediately as subclass has indicated that it hasn't completed - // authentication return; } - this.sessionStrategy.onAuthentication(authResult, request, response); + this.sessionStrategy.onAuthentication(authenticationResult, request, response); + // Authentication success + if (this.continueChainBeforeSuccessfulAuthentication) { + chain.doFilter(request, response); + } + successfulAuthentication(request, response, chain, authenticationResult); } catch (InternalAuthenticationServiceException failed) { this.logger.error("An internal error occurred while trying to authenticate the user.", failed); unsuccessfulAuthentication(request, response, failed); - - return; } - catch (AuthenticationException failed) { + catch (AuthenticationException ex) { // Authentication failed - unsuccessfulAuthentication(request, response, failed); - - return; + unsuccessfulAuthentication(request, response, ex); } - - // Authentication success - if (this.continueChainBeforeSuccessfulAuthentication) { - chain.doFilter(request, response); - } - - successfulAuthentication(request, response, chain, authResult); } /** @@ -316,20 +304,13 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt */ protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authResult) throws IOException, ServletException { - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Authentication success. Updating SecurityContextHolder to contain: " + authResult); - } - + this.logger.debug( + LogMessage.format("Authentication success. Updating SecurityContextHolder to contain: %s", authResult)); SecurityContextHolder.getContext().setAuthentication(authResult); - this.rememberMeServices.loginSuccess(request, response, authResult); - - // Fire event if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } - this.successHandler.onAuthenticationSuccess(request, response, authResult); } @@ -347,15 +328,12 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException { SecurityContextHolder.clearContext(); - if (this.logger.isDebugEnabled()) { this.logger.debug("Authentication request failed: " + failed.toString(), failed); this.logger.debug("Updated SecurityContextHolder to contain null Authentication"); this.logger.debug("Delegating to authentication failure handler " + this.failureHandler); } - this.rememberMeServices.loginFail(request, response); - this.failureHandler.onAuthenticationFailure(request, response, failed); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationTargetUrlRequestHandler.java b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationTargetUrlRequestHandler.java index 32e38c8525..f52b0a2a96 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationTargetUrlRequestHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationTargetUrlRequestHandler.java @@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; @@ -84,18 +85,16 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler { protected void handle(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException { String targetUrl = determineTargetUrl(request, response, authentication); - if (response.isCommitted()) { - this.logger.debug("Response has already been committed. Unable to redirect to " + targetUrl); + this.logger.debug( + LogMessage.format("Response has already been committed. Unable to redirect to %s", targetUrl)); return; } - this.redirectStrategy.sendRedirect(request, response, targetUrl); } /** * Builds the target URL according to the logic defined in the main class Javadoc - * * @since 5.2 */ protected String determineTargetUrl(HttpServletRequest request, HttpServletResponse response, @@ -110,30 +109,23 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler { if (isAlwaysUseDefaultTargetUrl()) { return this.defaultTargetUrl; } - // Check for the parameter and use that if available String targetUrl = null; - if (this.targetUrlParameter != null) { targetUrl = request.getParameter(this.targetUrlParameter); - if (StringUtils.hasText(targetUrl)) { this.logger.debug("Found targetUrlParameter in request: " + targetUrl); - return targetUrl; } } - if (this.useReferer && !StringUtils.hasLength(targetUrl)) { targetUrl = request.getHeader("Referer"); this.logger.debug("Using Referer header: " + targetUrl); } - if (!StringUtils.hasText(targetUrl)) { targetUrl = this.defaultTargetUrl; this.logger.debug("Using default Url: " + targetUrl); } - return targetUrl; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/AnonymousAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AnonymousAuthenticationFilter.java index fcd57272fe..1b29bebbe2 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AnonymousAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AnonymousAuthenticationFilter.java @@ -26,6 +26,7 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.core.Authentication; @@ -85,31 +86,24 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements @Override public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException { - if (SecurityContextHolder.getContext().getAuthentication() == null) { SecurityContextHolder.getContext().setAuthentication(createAuthentication((HttpServletRequest) req)); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Populated SecurityContextHolder with anonymous token: '" - + SecurityContextHolder.getContext().getAuthentication() + "'"); - } + this.logger.debug(LogMessage.of(() -> "Populated SecurityContextHolder with anonymous token: '" + + SecurityContextHolder.getContext().getAuthentication() + "'")); } else { - if (this.logger.isDebugEnabled()) { - this.logger.debug("SecurityContextHolder not populated with anonymous token, as it already contained: '" - + SecurityContextHolder.getContext().getAuthentication() + "'"); - } + this.logger.debug(LogMessage + .of(() -> "SecurityContextHolder not populated with anonymous token, as it already contained: '" + + SecurityContextHolder.getContext().getAuthentication() + "'")); } - chain.doFilter(req, res); } protected Authentication createAuthentication(HttpServletRequest request) { - AnonymousAuthenticationToken auth = new AnonymousAuthenticationToken(this.key, this.principal, + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken(this.key, this.principal, this.authorities); - auth.setDetails(this.authenticationDetailsSource.buildDetails(request)); - - return auth; + token.setDetails(this.authenticationDetailsSource.buildDetails(request)); + return token; } public void setAuthenticationDetailsSource( diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java index f0f375b8ba..0c6040f099 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java @@ -29,7 +29,7 @@ import org.springframework.util.Assert; /** * Adapts a {@link AuthenticationEntryPoint} into a {@link AuthenticationFailureHandler} * - * @author sbespalov + * @author Sergey Bespalov * @since 5.2.0 */ public class AuthenticationEntryPointFailureHandler implements AuthenticationFailureHandler { diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java index 25f0f1e0b9..89cb9cc5d9 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java @@ -84,7 +84,6 @@ public class AuthenticationFilter extends OncePerRequestFilter { AuthenticationConverter authenticationConverter) { Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); - this.authenticationManagerResolver = authenticationManagerResolver; this.authenticationConverter = authenticationConverter; } @@ -142,19 +141,16 @@ public class AuthenticationFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); return; } - try { Authentication authenticationResult = attemptAuthentication(request, response); if (authenticationResult == null) { filterChain.doFilter(request, response); return; } - HttpSession session = request.getSession(false); if (session != null) { request.changeSessionId(); } - successfulAuthentication(request, response, filterChain, authenticationResult); } catch (AuthenticationException ex) { @@ -182,13 +178,11 @@ public class AuthenticationFilter extends OncePerRequestFilter { if (authentication == null) { return null; } - AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request); Authentication authenticationResult = authenticationManager.authenticate(authentication); if (authenticationResult == null) { throw new ServletException("AuthenticationManager should not return null Authentication object."); } - return authenticationResult; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationEntryPoint.java index d9d2a5cbbd..d67dbf604f 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationEntryPoint.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.util.matcher.ELRequestMatcher; @@ -62,7 +63,7 @@ import org.springframework.util.Assert; */ public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPoint, InitializingBean { - private final Log logger = LogFactory.getLog(getClass()); + private static final Log logger = LogFactory.getLog(DelegatingAuthenticationEntryPoint.class); private final LinkedHashMap entryPoints; @@ -75,25 +76,16 @@ public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPo @Override public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException { - for (RequestMatcher requestMatcher : this.entryPoints.keySet()) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Trying to match using " + requestMatcher); - } + logger.debug(LogMessage.format("Trying to match using %s", requestMatcher)); if (requestMatcher.matches(request)) { AuthenticationEntryPoint entryPoint = this.entryPoints.get(requestMatcher); - if (this.logger.isDebugEnabled()) { - this.logger.debug("Match found! Executing " + entryPoint); - } + logger.debug(LogMessage.format("Match found! Executing %s", entryPoint)); entryPoint.commence(request, response, authException); return; } } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("No match found. Using default entry point " + this.defaultEntryPoint); - } - + logger.debug(LogMessage.format("No match found. Using default entry point %s", this.defaultEntryPoint)); // No EntryPoint matched, use defaultEntryPoint this.defaultEntryPoint.commence(request, response, authException); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandler.java b/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandler.java index 41ec7712a8..4891c1ea83 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandler.java @@ -62,9 +62,6 @@ public class DelegatingAuthenticationFailureHandler implements AuthenticationFai this.defaultHandler = defaultHandler; } - /** - * {@inheritDoc} - */ @Override public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException { diff --git a/web/src/main/java/org/springframework/security/web/authentication/ExceptionMappingAuthenticationFailureHandler.java b/web/src/main/java/org/springframework/security/web/authentication/ExceptionMappingAuthenticationFailureHandler.java index a85603ea7b..5a619000f3 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ExceptionMappingAuthenticationFailureHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ExceptionMappingAuthenticationFailureHandler.java @@ -49,7 +49,6 @@ public class ExceptionMappingAuthenticationFailureHandler extends SimpleUrlAuthe public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException { String url = this.failureUrlMap.get(exception.getClass().getName()); - if (url != null) { getRedirectStrategy().sendRedirect(request, response, url); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/Http403ForbiddenEntryPoint.java b/web/src/main/java/org/springframework/security/web/authentication/Http403ForbiddenEntryPoint.java index bff458f006..216654945c 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/Http403ForbiddenEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/authentication/Http403ForbiddenEntryPoint.java @@ -55,9 +55,7 @@ public class Http403ForbiddenEntryPoint implements AuthenticationEntryPoint { @Override public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException arg2) throws IOException { - if (logger.isDebugEnabled()) { - logger.debug("Pre-authenticated entry point called. Rejecting access"); - } + logger.debug("Pre-authenticated entry point called. Rejecting access"); response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied"); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java index ca59b6842f..08de9369c0 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.DefaultRedirectStrategy; @@ -93,9 +94,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin public void afterPropertiesSet() { Assert.isTrue(StringUtils.hasText(this.loginFormUrl) && UrlUtils.isValidRedirectUrl(this.loginFormUrl), "loginFormUrl must be specified and must be a valid redirect URL"); - if (this.useForward && UrlUtils.isAbsoluteUrl(this.loginFormUrl)) { - throw new IllegalArgumentException("useForward must be false if using an absolute loginFormURL"); - } + Assert.isTrue(!this.useForward || !UrlUtils.isAbsoluteUrl(this.loginFormUrl), + "useForward must be false if using an absolute loginFormURL"); Assert.notNull(this.portMapper, "portMapper must be specified"); Assert.notNull(this.portResolver, "portResolver must be specified"); } @@ -110,7 +110,6 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin */ protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) { - return getLoginFormUrl(); } @@ -120,75 +119,55 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin @Override public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException { - - String redirectUrl = null; - - if (this.useForward) { - - if (this.forceHttps && "http".equals(request.getScheme())) { - // First redirect the current request to HTTPS. - // When that request is received, the forward to the login page will be - // used. - redirectUrl = buildHttpsRedirectUrlForRequest(request); - } - - if (redirectUrl == null) { - String loginForm = determineUrlToUseForThisRequest(request, response, authException); - - if (logger.isDebugEnabled()) { - logger.debug("Server side forward to: " + loginForm); - } - - RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm); - - dispatcher.forward(request, response); - - return; - } - } - else { + if (!this.useForward) { // redirect to login page. Use https if forceHttps true - - redirectUrl = buildRedirectUrlToLoginPage(request, response, authException); - + String redirectUrl = buildRedirectUrlToLoginPage(request, response, authException); + this.redirectStrategy.sendRedirect(request, response, redirectUrl); + return; } - - this.redirectStrategy.sendRedirect(request, response, redirectUrl); + String redirectUrl = null; + if (this.forceHttps && "http".equals(request.getScheme())) { + // First redirect the current request to HTTPS. When that request is received, + // the forward to the login page will be used. + redirectUrl = buildHttpsRedirectUrlForRequest(request); + } + if (redirectUrl != null) { + this.redirectStrategy.sendRedirect(request, response, redirectUrl); + return; + } + String loginForm = determineUrlToUseForThisRequest(request, response, authException); + logger.debug(LogMessage.format("Server side forward to: %s", loginForm)); + RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm); + dispatcher.forward(request, response); + return; } protected String buildRedirectUrlToLoginPage(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) { - String loginForm = determineUrlToUseForThisRequest(request, response, authException); - if (UrlUtils.isAbsoluteUrl(loginForm)) { return loginForm; } - int serverPort = this.portResolver.getServerPort(request); String scheme = request.getScheme(); - RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder(); - urlBuilder.setScheme(scheme); urlBuilder.setServerName(request.getServerName()); urlBuilder.setPort(serverPort); urlBuilder.setContextPath(request.getContextPath()); urlBuilder.setPathInfo(loginForm); - if (this.forceHttps && "http".equals(scheme)) { Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort); - if (httpsPort != null) { // Overwrite scheme and port in the redirect URL urlBuilder.setScheme("https"); urlBuilder.setPort(httpsPort); } else { - logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort); + logger.warn(LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s", + serverPort)); } } - return urlBuilder.getUrl(); } @@ -197,10 +176,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin * current request to HTTPS, before doing a forward to the login page. */ protected String buildHttpsRedirectUrlForRequest(HttpServletRequest request) throws IOException, ServletException { - int serverPort = this.portResolver.getServerPort(request); Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort); - if (httpsPort != null) { RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder(); urlBuilder.setScheme("https"); @@ -210,13 +187,11 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin urlBuilder.setServletPath(request.getServletPath()); urlBuilder.setPathInfo(request.getPathInfo()); urlBuilder.setQuery(request.getQueryString()); - return urlBuilder.getUrl(); } - // Fall through to server-side forward with warning message - logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort); - + logger.warn( + LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s", serverPort)); return null; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java index 8eab8096ca..3fa3802948 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/SavedRequestAwareAuthenticationSuccessHandler.java @@ -74,10 +74,8 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws ServletException, IOException { SavedRequest savedRequest = this.requestCache.getRequest(request, response); - if (savedRequest == null) { super.onAuthenticationSuccess(request, response, authentication); - return; } String targetUrlParameter = getTargetUrlParameter(); @@ -85,12 +83,9 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth || (targetUrlParameter != null && StringUtils.hasText(request.getParameter(targetUrlParameter)))) { this.requestCache.removeRequest(request, response); super.onAuthenticationSuccess(request, response, authentication); - return; } - clearAuthenticationAttributes(request); - // Use the DefaultSavedRequest URL String targetUrl = savedRequest.getRedirectUrl(); this.logger.debug("Redirecting to DefaultSavedRequest Url: " + targetUrl); diff --git a/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationFailureHandler.java b/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationFailureHandler.java index f376014279..8eba0ffa1d 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationFailureHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationFailureHandler.java @@ -76,24 +76,19 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail @Override public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException { - if (this.defaultFailureUrl == null) { this.logger.debug("No failure URL set, sending 401 Unauthorized error"); - response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase()); + return; + } + saveException(request, exception); + if (this.forwardToDestination) { + this.logger.debug("Forwarding to " + this.defaultFailureUrl); + request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response); } else { - saveException(request, exception); - - if (this.forwardToDestination) { - this.logger.debug("Forwarding to " + this.defaultFailureUrl); - - request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response); - } - else { - this.logger.debug("Redirecting to " + this.defaultFailureUrl); - this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl); - } + this.logger.debug("Redirecting to " + this.defaultFailureUrl); + this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl); } } @@ -108,13 +103,11 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail protected final void saveException(HttpServletRequest request, AuthenticationException exception) { if (this.forwardToDestination) { request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception); + return; } - else { - HttpSession session = request.getSession(false); - - if (session != null || this.allowSessionCreation) { - request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception); - } + HttpSession session = request.getSession(false); + if (session != null || this.allowSessionCreation) { + request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationSuccessHandler.java index 4a7647239a..b8ac058cd8 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/SimpleUrlAuthenticationSuccessHandler.java @@ -59,7 +59,6 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio @Override public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException { - handle(request, response, authentication); clearAuthenticationAttributes(request); } @@ -70,12 +69,9 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio */ protected final void clearAuthenticationAttributes(HttpServletRequest request) { HttpSession session = request.getSession(false); - - if (session == null) { - return; + if (session != null) { + session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION); } - - session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilter.java index b1f94f7e3f..e1a444594f 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilter.java @@ -74,25 +74,14 @@ public class UsernamePasswordAuthenticationFilter extends AbstractAuthentication if (this.postOnly && !request.getMethod().equals("POST")) { throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod()); } - String username = obtainUsername(request); - String password = obtainPassword(request); - - if (username == null) { - username = ""; - } - - if (password == null) { - password = ""; - } - + username = (username != null) ? username : ""; username = username.trim(); - + String password = obtainPassword(request); + password = (password != null) ? password : ""; UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken(username, password); - // Allow subclasses to set the "details" property setDetails(request, authRequest); - return this.getAuthenticationManager().authenticate(authRequest); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/WebAuthenticationDetails.java b/web/src/main/java/org/springframework/security/web/authentication/WebAuthenticationDetails.java index 24eef04985..41052b4fb9 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/WebAuthenticationDetails.java +++ b/web/src/main/java/org/springframework/security/web/authentication/WebAuthenticationDetails.java @@ -44,7 +44,6 @@ public class WebAuthenticationDetails implements Serializable { */ public WebAuthenticationDetails(HttpServletRequest request) { this.remoteAddress = request.getRemoteAddr(); - HttpSession session = request.getSession(false); this.sessionId = (session != null) ? session.getId() : null; } @@ -62,39 +61,31 @@ public class WebAuthenticationDetails implements Serializable { @Override public boolean equals(Object obj) { if (obj instanceof WebAuthenticationDetails) { - WebAuthenticationDetails rhs = (WebAuthenticationDetails) obj; - - if ((this.remoteAddress == null) && (rhs.getRemoteAddress() != null)) { + WebAuthenticationDetails other = (WebAuthenticationDetails) obj; + if ((this.remoteAddress == null) && (other.getRemoteAddress() != null)) { return false; } - - if ((this.remoteAddress != null) && (rhs.getRemoteAddress() == null)) { + if ((this.remoteAddress != null) && (other.getRemoteAddress() == null)) { return false; } - if (this.remoteAddress != null) { - if (!this.remoteAddress.equals(rhs.getRemoteAddress())) { + if (!this.remoteAddress.equals(other.getRemoteAddress())) { return false; } } - - if ((this.sessionId == null) && (rhs.getSessionId() != null)) { + if ((this.sessionId == null) && (other.getSessionId() != null)) { return false; } - - if ((this.sessionId != null) && (rhs.getSessionId() == null)) { + if ((this.sessionId != null) && (other.getSessionId() == null)) { return false; } - if (this.sessionId != null) { - if (!this.sessionId.equals(rhs.getSessionId())) { + if (!this.sessionId.equals(other.getSessionId())) { return false; } } - return true; } - return false; } @@ -118,15 +109,12 @@ public class WebAuthenticationDetails implements Serializable { @Override public int hashCode() { int code = 7654; - if (this.remoteAddress != null) { code = code * (this.remoteAddress.hashCode() % 7); } - if (this.sessionId != null) { code = code * (this.sessionId.hashCode() % 7); } - return code; } @@ -136,7 +124,6 @@ public class WebAuthenticationDetails implements Serializable { sb.append(super.toString()).append(": "); sb.append("RemoteIpAddress: ").append(this.getRemoteAddress()).append("; "); sb.append("SessionId: ").append(this.getSessionId()); - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/logout/CookieClearingLogoutHandler.java b/web/src/main/java/org/springframework/security/web/authentication/logout/CookieClearingLogoutHandler.java index 70272a7893..e47c07dce1 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/logout/CookieClearingLogoutHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/logout/CookieClearingLogoutHandler.java @@ -43,15 +43,14 @@ public final class CookieClearingLogoutHandler implements LogoutHandler { Assert.notNull(cookiesToClear, "List of cookies cannot be null"); List> cookieList = new ArrayList<>(); for (String cookieName : cookiesToClear) { - Function f = (request) -> { + cookieList.add((request) -> { Cookie cookie = new Cookie(cookieName, null); String cookiePath = request.getContextPath() + "/"; cookie.setPath(cookiePath); cookie.setMaxAge(0); cookie.setSecure(request.isSecure()); return cookie; - }; - cookieList.add(f); + }); } this.cookiesToClear = cookieList; } @@ -65,8 +64,7 @@ public final class CookieClearingLogoutHandler implements LogoutHandler { List> cookieList = new ArrayList<>(); for (Cookie cookie : cookiesToClear) { Assert.isTrue(cookie.getMaxAge() == 0, "Cookie maxAge must be 0"); - Function f = (request) -> cookie; - cookieList.add(f); + cookieList.add((request) -> cookie); } this.cookiesToClear = cookieList; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/logout/LogoutFilter.java b/web/src/main/java/org/springframework/security/web/authentication/logout/LogoutFilter.java index ff9ad94f07..54ec8ee6bb 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/logout/LogoutFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/logout/LogoutFilter.java @@ -25,6 +25,7 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.util.UrlUtils; @@ -83,25 +84,20 @@ public class LogoutFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { if (requiresLogout(request, response)) { Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Logging out user '" + auth + "' and transferring to logout destination"); - } - + this.logger.debug(LogMessage.format("Logging out user '%s' and transferring to logout destination", auth)); this.handler.logout(request, response, auth); - this.logoutSuccessHandler.onLogoutSuccess(request, response, auth); - return; } - chain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/logout/SecurityContextLogoutHandler.java b/web/src/main/java/org/springframework/security/web/authentication/logout/SecurityContextLogoutHandler.java index de82f99994..76c5104553 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/logout/SecurityContextLogoutHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/logout/SecurityContextLogoutHandler.java @@ -23,6 +23,7 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -61,16 +62,14 @@ public class SecurityContextLogoutHandler implements LogoutHandler { if (this.invalidateHttpSession) { HttpSession session = request.getSession(false); if (session != null) { - this.logger.debug("Invalidating session: " + session.getId()); + this.logger.debug(LogMessage.format("Invalidating session: %s", session.getId())); session.invalidate(); } } - if (this.clearAuthentication) { SecurityContext context = SecurityContextHolder.getContext(); context.setAuthentication(null); } - SecurityContextHolder.clearContext(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java index 65b1fa56ed..b31cd91473 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java @@ -28,6 +28,7 @@ import javax.servlet.http.HttpSession; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent; @@ -124,16 +125,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - - if (this.logger.isDebugEnabled()) { - this.logger - .debug("Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication()); - } - + this.logger.debug(LogMessage + .of(() -> "Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication())); if (this.requiresAuthenticationRequestMatcher.matches((HttpServletRequest) request)) { doAuthenticate((HttpServletRequest) request, (HttpServletResponse) response); } - chain.doFilter(request, response); } @@ -156,21 +152,15 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi * @return true if the principal has changed, else false */ protected boolean principalChanged(HttpServletRequest request, Authentication currentAuthentication) { - Object principal = getPreAuthenticatedPrincipal(request); - if ((principal instanceof String) && currentAuthentication.getName().equals(principal)) { return false; } - if (principal != null && principal.equals(currentAuthentication.getPrincipal())) { return false; } - - if (this.logger.isDebugEnabled()) { - this.logger - .debug("Pre-authenticated principal has changed to " + principal + " and will be reauthenticated"); - } + this.logger.debug(LogMessage.format("Pre-authenticated principal has changed to %s and will be reauthenticated", + principal)); return true; } @@ -179,35 +169,24 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi */ private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { - Authentication authResult; - Object principal = getPreAuthenticatedPrincipal(request); - Object credentials = getPreAuthenticatedCredentials(request); - if (principal == null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("No pre-authenticated principal found in request"); - } - + this.logger.debug("No pre-authenticated principal found in request"); return; } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("preAuthenticatedPrincipal = " + principal + ", trying to authenticate"); - } - + this.logger.debug(LogMessage.format("preAuthenticatedPrincipal = %s, trying to authenticate", principal)); + Object credentials = getPreAuthenticatedCredentials(request); try { - PreAuthenticatedAuthenticationToken authRequest = new PreAuthenticatedAuthenticationToken(principal, - credentials); - authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - authResult = this.authenticationManager.authenticate(authRequest); - successfulAuthentication(request, response, authResult); + PreAuthenticatedAuthenticationToken authenticationRequest = new PreAuthenticatedAuthenticationToken( + principal, credentials); + authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); + Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest); + successfulAuthentication(request, response, authenticationResult); } - catch (AuthenticationException failed) { - unsuccessfulAuthentication(request, response, failed); - + catch (AuthenticationException ex) { + unsuccessfulAuthentication(request, response, ex); if (!this.continueFilterChainOnUnsuccessfulAuthentication) { - throw failed; + throw ex; } } } @@ -218,15 +197,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi */ protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, Authentication authResult) throws IOException, ServletException { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Authentication success: " + authResult); - } + this.logger.debug(LogMessage.format("Authentication success: %s", authResult)); SecurityContextHolder.getContext().setAuthentication(authResult); - // Fire event if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } - if (this.authenticationSuccessHandler != null) { this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult); } @@ -241,12 +216,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException { SecurityContextHolder.clearContext(); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Cleared security context due to exception", failed); - } + this.logger.debug("Cleared security context due to exception", failed); request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed); - if (this.authenticationFailureHandler != null) { this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); } @@ -355,36 +326,27 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi @Override public boolean matches(HttpServletRequest request) { - Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); - if (currentUser == null) { return true; } - if (!AbstractPreAuthenticatedProcessingFilter.this.checkForPrincipalChanges) { return false; } - if (!principalChanged(request, currentUser)) { return false; } - AbstractPreAuthenticatedProcessingFilter.this.logger .debug("Pre-authenticated principal has changed and will be reauthenticated"); - if (AbstractPreAuthenticatedProcessingFilter.this.invalidateSessionOnPrincipalChange) { SecurityContextHolder.clearContext(); - HttpSession session = request.getSession(false); - if (session != null) { AbstractPreAuthenticatedProcessingFilter.this.logger.debug("Invalidating existing session"); session.invalidate(); request.getSession(); } } - return true; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedAuthenticationProvider.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedAuthenticationProvider.java index f34e36fbaa..320797d077 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedAuthenticationProvider.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedAuthenticationProvider.java @@ -21,6 +21,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.core.Ordered; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.BadCredentialsException; @@ -50,11 +51,11 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro private static final Log logger = LogFactory.getLog(PreAuthenticatedAuthenticationProvider.class); - private AuthenticationUserDetailsService preAuthenticatedUserDetailsService = null; + private AuthenticationUserDetailsService preAuthenticatedUserDetailsService; private UserDetailsChecker userDetailsChecker = new AccountStatusUserDetailsChecker(); - private boolean throwExceptionWhenTokenRejected = false; + private boolean throwExceptionWhenTokenRejected; private int order = -1; // default: same as non-ordered @@ -77,38 +78,27 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro if (!supports(authentication.getClass())) { return null; } - - if (logger.isDebugEnabled()) { - logger.debug("PreAuthenticated authentication request: " + authentication); - } - + logger.debug(LogMessage.format("PreAuthenticated authentication request: %s", authentication)); if (authentication.getPrincipal() == null) { logger.debug("No pre-authenticated principal found in request."); - if (this.throwExceptionWhenTokenRejected) { throw new BadCredentialsException("No pre-authenticated principal found in request."); } return null; } - if (authentication.getCredentials() == null) { logger.debug("No pre-authenticated credentials found in request."); - if (this.throwExceptionWhenTokenRejected) { throw new BadCredentialsException("No pre-authenticated credentials found in request."); } return null; } - - UserDetails ud = this.preAuthenticatedUserDetailsService + UserDetails userDetails = this.preAuthenticatedUserDetailsService .loadUserDetails((PreAuthenticatedAuthenticationToken) authentication); - - this.userDetailsChecker.check(ud); - - PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(ud, - authentication.getCredentials(), ud.getAuthorities()); + this.userDetailsChecker.check(userDetails); + PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(userDetails, + authentication.getCredentials(), userDetails.getAuthorities()); result.setDetails(authentication.getDetails()); - return result; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails.java index f104f2d514..f9534b5a90 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails.java @@ -46,7 +46,6 @@ public class PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails extends public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(HttpServletRequest request, Collection authorities) { super(request); - List temp = new ArrayList<>(authorities.size()); temp.addAll(authorities); this.authorities = Collections.unmodifiableList(temp); diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestAttributeAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestAttributeAuthenticationFilter.java index 7eb9693acc..6d9176a230 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestAttributeAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestAttributeAuthenticationFilter.java @@ -59,12 +59,10 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat @Override protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { String principal = (String) request.getAttribute(this.principalEnvironmentVariable); - if (principal == null && this.exceptionIfVariableMissing) { throw new PreAuthenticatedCredentialsNotFoundException( this.principalEnvironmentVariable + " variable not found in request."); } - return principal; } @@ -78,7 +76,6 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat if (this.credentialsEnvironmentVariable != null) { return request.getAttribute(this.credentialsEnvironmentVariable); } - return "N/A"; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestHeaderAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestHeaderAuthenticationFilter.java index 0af0c4e55f..4f5fbee3f5 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestHeaderAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/RequestHeaderAuthenticationFilter.java @@ -60,12 +60,10 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP @Override protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { String principal = request.getHeader(this.principalRequestHeader); - if (principal == null && this.exceptionIfHeaderMissing) { throw new PreAuthenticatedCredentialsNotFoundException( this.principalRequestHeader + " header not found in request."); } - return principal; } @@ -79,7 +77,6 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP if (this.credentialsRequestHeader != null) { return request.getHeader(this.credentialsRequestHeader); } - return "N/A"; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.java index b819a575d6..4fbd010402 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper; @@ -76,13 +77,11 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements */ protected Collection getUserRoles(HttpServletRequest request) { ArrayList j2eeUserRolesList = new ArrayList<>(); - for (String role : this.j2eeMappableRoles) { if (request.isUserInRole(role)) { j2eeUserRolesList.add(role); } } - return j2eeUserRolesList; } @@ -93,19 +92,14 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements */ @Override public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails buildDetails(HttpServletRequest context) { - Collection j2eeUserRoles = getUserRoles(context); - Collection userGas = this.j2eeUserRoles2GrantedAuthoritiesMapper + Collection userGrantedAuthorities = this.j2eeUserRoles2GrantedAuthoritiesMapper .getGrantedAuthorities(j2eeUserRoles); - if (this.logger.isDebugEnabled()) { - this.logger.debug("J2EE roles [" + j2eeUserRoles + "] mapped to Granted Authorities: [" + userGas + "]"); + this.logger.debug(LogMessage.format("J2EE roles [%s] mapped to Granted Authorities: [%s]", j2eeUserRoles, + userGrantedAuthorities)); } - - PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails result = new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails( - context, userGas); - - return result; + return new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(context, userGrantedAuthorities); } /** diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eePreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eePreAuthenticatedProcessingFilter.java index 4fce63bd38..6d5a5dfa52 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eePreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/J2eePreAuthenticatedProcessingFilter.java @@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.j2ee; import javax.servlet.http.HttpServletRequest; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; /** @@ -36,9 +37,7 @@ public class J2eePreAuthenticatedProcessingFilter extends AbstractPreAuthenticat @Override protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) { Object principal = (httpRequest.getUserPrincipal() != null) ? httpRequest.getUserPrincipal().getName() : null; - if (this.logger.isDebugEnabled()) { - this.logger.debug("PreAuthenticated J2EE principal: " + principal); - } + this.logger.debug(LogMessage.format("PreAuthenticated J2EE principal: %s", principal)); return principal; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/WebXmlMappableAttributesRetriever.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/WebXmlMappableAttributesRetriever.java index a7b685853f..18d01fb706 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/WebXmlMappableAttributesRetriever.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/j2ee/WebXmlMappableAttributesRetriever.java @@ -22,6 +22,7 @@ import java.io.StringReader; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import javax.xml.parsers.DocumentBuilder; @@ -43,6 +44,7 @@ import org.springframework.context.ResourceLoaderAware; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.security.core.authority.mapping.MappableAttributesRetriever; +import org.springframework.util.Assert; /** * This MappableAttributesRetriever implementation reads the list of defined J2EE @@ -82,17 +84,17 @@ public class WebXmlMappableAttributesRetriever Resource webXml = this.resourceLoader.getResource("/WEB-INF/web.xml"); Document doc = getDocument(webXml.getInputStream()); NodeList webApp = doc.getElementsByTagName("web-app"); - if (webApp.getLength() != 1) { - throw new IllegalArgumentException("Failed to find 'web-app' element in resource" + webXml); - } + Assert.isTrue(webApp.getLength() == 1, () -> "Failed to find 'web-app' element in resource" + webXml); NodeList securityRoles = ((Element) webApp.item(0)).getElementsByTagName("security-role"); + List roleNames = getRoleNames(webXml, securityRoles); + this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames)); + } + private List getRoleNames(Resource webXml, NodeList securityRoles) { ArrayList roleNames = new ArrayList<>(); - for (int i = 0; i < securityRoles.getLength(); i++) { - Element secRoleElt = (Element) securityRoles.item(i); - NodeList roles = secRoleElt.getElementsByTagName("role-name"); - + Element securityRoleElement = (Element) securityRoles.item(i); + NodeList roles = securityRoleElement.getElementsByTagName("role-name"); if (roles.getLength() > 0) { String roleName = roles.item(0).getTextContent().trim(); roleNames.add(roleName); @@ -102,22 +104,19 @@ public class WebXmlMappableAttributesRetriever this.logger.info("No security-role elements found in " + webXml); } } - - this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames)); + return roleNames; } /** * @return Document for the specified InputStream */ private Document getDocument(InputStream aStream) { - Document doc; try { DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); factory.setValidating(false); - DocumentBuilder db = factory.newDocumentBuilder(); - db.setEntityResolver(new MyEntityResolver()); - doc = db.parse(aStream); - return doc; + DocumentBuilder builder = factory.newDocumentBuilder(); + builder.setEntityResolver(new MyEntityResolver()); + return builder.parse(aStream); } catch (FactoryConfigurationError | IOException | SAXException | ParserConfigurationException ex) { throw new RuntimeException("Unable to parse document object", ex); diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java index a95723a3a4..a934a0e35e 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java @@ -31,6 +31,8 @@ import javax.security.auth.Subject; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; + /** * WebSphere Security helper class to allow retrieval of the current username and groups. *

@@ -75,9 +77,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups * @return String the security name for the given subject */ private static String getSecurityName(final Subject subject) { - if (logger.isDebugEnabled()) { - logger.debug("Determining Websphere security name for subject " + subject); - } + logger.debug(LogMessage.format("Determining Websphere security name for subject %s", subject)); String userSecurityName = null; if (subject != null) { // SEC-803 @@ -86,9 +86,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups userSecurityName = (String) invokeMethod(getSecurityNameMethod(), credential); } } - if (logger.isDebugEnabled()) { - logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject); - } + logger.debug(LogMessage.format("Websphere security name is %s for subject %s", subject, userSecurityName)); return userSecurityName; } @@ -119,69 +117,56 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups */ @SuppressWarnings("unchecked") private static List getWebSphereGroups(final String securityName) { - Context ic = null; + Context context = null; try { // TODO: Cache UserRegistry object - ic = new InitialContext(); - Object objRef = ic.lookup(USER_REGISTRY); + context = new InitialContext(); + Object objRef = context.lookup(USER_REGISTRY); Object userReg = invokeMethod(getNarrowMethod(), null, objRef, Class.forName("com.ibm.websphere.security.UserRegistry")); - if (logger.isDebugEnabled()) { - logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry " - + userReg); - } - final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg, + logger.debug(LogMessage.format("Determining WebSphere groups for user %s using WebSphere UserRegistry %s", + securityName, userReg)); + final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg, new Object[] { securityName }); - if (logger.isDebugEnabled()) { - logger.debug("Groups for user " + securityName + ": " + groups.toString()); - } - - return new ArrayList(groups); + logger.debug(LogMessage.format("Groups for user %s: %s", securityName, groups)); + return new ArrayList(groups); } catch (Exception ex) { logger.error("Exception occured while looking up groups for user", ex); throw new RuntimeException("Exception occured while looking up groups for user", ex); } finally { - try { - if (ic != null) { - ic.close(); - } - } - catch (NamingException ex) { - logger.debug("Exception occured while closing context", ex); + closeContext(context); + } + } + + private static void closeContext(Context context) { + try { + if (context != null) { + context.close(); } } + catch (NamingException ex) { + logger.debug("Exception occured while closing context", ex); + } } private static Object invokeMethod(Method method, Object instance, Object... args) { try { return method.invoke(instance, args); } - catch (IllegalArgumentException ex) { - logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "(" - + Arrays.asList(args) + ")", ex); - throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "." - + method.getName() + "(" + Arrays.asList(args) + ")", ex); - } - catch (IllegalAccessException ex) { - logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "(" - + Arrays.asList(args) + ")", ex); - throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "." - + method.getName() + "(" + Arrays.asList(args) + ")", ex); - } - catch (InvocationTargetException ex) { - logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "(" - + Arrays.asList(args) + ")", ex); - throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "." - + method.getName() + "(" + Arrays.asList(args) + ")", ex); + catch (IllegalArgumentException | IllegalAccessException | InvocationTargetException ex) { + String message = "Error while invoking method " + method.getClass().getName() + "." + method.getName() + "(" + + Arrays.asList(args) + ")"; + logger.error(message, ex); + throw new RuntimeException(message, ex); } } private static Method getMethod(String className, String methodName, String[] parameterTypeNames) { try { Class c = Class.forName(className); - final int len = parameterTypeNames.length; + int len = parameterTypeNames.length; Class[] parameterTypes = new Class[len]; for (int i = 0; i < len; i++) { parameterTypes[i] = Class.forName(parameterTypeNames[i]); diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedProcessingFilter.java index f867d52a82..e58ee0b182 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedProcessingFilter.java @@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.websphere; import javax.servlet.http.HttpServletRequest; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; /** @@ -51,9 +52,7 @@ public class WebSpherePreAuthenticatedProcessingFilter extends AbstractPreAuthen @Override protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) { Object principal = this.wasHelper.getCurrentUserName(); - if (this.logger.isDebugEnabled()) { - this.logger.debug("PreAuthenticated WebSphere principal: " + principal); - } + this.logger.debug(LogMessage.format("PreAuthenticated WebSphere principal: %s", principal)); return principal; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedWebAuthenticationDetailsSource.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedWebAuthenticationDetailsSource.java index fdb6f1a3aa..5f52f928f6 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedWebAuthenticationDetailsSource.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/WebSpherePreAuthenticatedWebAuthenticationDetailsSource.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper; @@ -68,9 +69,8 @@ public class WebSpherePreAuthenticatedWebAuthenticationDetailsSource implements List webSphereGroups = this.wasHelper.getGroupsForCurrentUser(); Collection userGas = this.webSphereGroups2GrantedAuthoritiesMapper .getGrantedAuthorities(webSphereGroups); - if (this.logger.isDebugEnabled()) { - this.logger.debug("WebSphere groups: " + webSphereGroups + " mapped to Granted Authorities: " + userGas); - } + this.logger.debug( + LogMessage.format("WebSphere groups: %s mapped to Granted Authorities: %s", webSphereGroups, userGas)); return userGas; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/SubjectDnX509PrincipalExtractor.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/SubjectDnX509PrincipalExtractor.java index e590ff70b2..56e5c33b9e 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/SubjectDnX509PrincipalExtractor.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/SubjectDnX509PrincipalExtractor.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.context.MessageSource; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.util.Assert; @@ -58,24 +59,15 @@ public class SubjectDnX509PrincipalExtractor implements X509PrincipalExtractor { public Object extractPrincipal(X509Certificate clientCert) { // String subjectDN = clientCert.getSubjectX500Principal().getName(); String subjectDN = clientCert.getSubjectDN().getName(); - - this.logger.debug("Subject DN is '" + subjectDN + "'"); - + this.logger.debug(LogMessage.format("Subject DN is '%s'", subjectDN)); Matcher matcher = this.subjectDnPattern.matcher(subjectDN); - if (!matcher.find()) { throw new BadCredentialsException(this.messages.getMessage("SubjectDnX509PrincipalExtractor.noMatching", new Object[] { subjectDN }, "No matching pattern was found in subject DN: {0}")); } - - if (matcher.groupCount() != 1) { - throw new IllegalArgumentException("Regular expression must contain a single group "); - } - + Assert.isTrue(matcher.groupCount() == 1, "Regular expression must contain a single group "); String username = matcher.group(1); - - this.logger.debug("Extracted Principal name is '" + username + "'"); - + this.logger.debug(LogMessage.format("Extracted Principal name is '%s'", username)); return username; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/X509AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/X509AuthenticationFilter.java index cdb4190ded..1692df44d0 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/X509AuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/x509/X509AuthenticationFilter.java @@ -20,6 +20,7 @@ import java.security.cert.X509Certificate; import javax.servlet.http.HttpServletRequest; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; /** @@ -32,12 +33,7 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing @Override protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) { X509Certificate cert = extractClientCertificate(request); - - if (cert == null) { - return null; - } - - return this.principalExtractor.extractPrincipal(cert); + return (cert != null) ? this.principalExtractor.extractPrincipal(cert) : null; } @Override @@ -47,19 +43,11 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing private X509Certificate extractClientCertificate(HttpServletRequest request) { X509Certificate[] certs = (X509Certificate[]) request.getAttribute("javax.servlet.request.X509Certificate"); - if (certs != null && certs.length > 0) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("X.509 client authentication certificate:" + certs[0]); - } - + this.logger.debug(LogMessage.format("X.509 client authentication certificate:%s", certs[0])); return certs[0]; } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("No client certificate found in request."); - } - + this.logger.debug("No client certificate found in request."); return null; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java index a250449dc8..1d25067d37 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java @@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AccountStatusException; import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AuthenticationDetailsSource; @@ -118,47 +119,38 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, @Override public final Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) { String rememberMeCookie = extractRememberMeCookie(request); - if (rememberMeCookie == null) { return null; } - this.logger.debug("Remember-me cookie detected"); - if (rememberMeCookie.length() == 0) { this.logger.debug("Cookie was empty"); cancelCookie(request, response); return null; } - - UserDetails user = null; - try { String[] cookieTokens = decodeCookie(rememberMeCookie); - user = processAutoLoginCookie(cookieTokens, request, response); + UserDetails user = processAutoLoginCookie(cookieTokens, request, response); this.userDetailsChecker.check(user); - this.logger.debug("Remember-me cookie accepted"); - return createSuccessfulAuthentication(request, user); } - catch (CookieTheftException cte) { + catch (CookieTheftException ex) { cancelCookie(request, response); - throw cte; + throw ex; } - catch (UsernameNotFoundException noUser) { - this.logger.debug("Remember-me login was valid but corresponding user not found.", noUser); + catch (UsernameNotFoundException ex) { + this.logger.debug("Remember-me login was valid but corresponding user not found.", ex); } - catch (InvalidCookieException invalidCookie) { - this.logger.debug("Invalid remember-me cookie: " + invalidCookie.getMessage()); + catch (InvalidCookieException ex) { + this.logger.debug("Invalid remember-me cookie: " + ex.getMessage()); } - catch (AccountStatusException statusInvalid) { - this.logger.debug("Invalid UserDetails: " + statusInvalid.getMessage()); + catch (AccountStatusException ex) { + this.logger.debug("Invalid UserDetails: " + ex.getMessage()); } catch (RememberMeAuthenticationException ex) { this.logger.debug(ex.getMessage()); } - cancelCookie(request, response); return null; } @@ -172,17 +164,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, */ protected String extractRememberMeCookie(HttpServletRequest request) { Cookie[] cookies = request.getCookies(); - if ((cookies == null) || (cookies.length == 0)) { return null; } - for (Cookie cookie : cookies) { if (this.cookieName.equals(cookie.getName())) { return cookie.getValue(); } } - return null; } @@ -216,18 +205,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, for (int j = 0; j < cookieValue.length() % 4; j++) { cookieValue = cookieValue + "="; } - try { Base64.getDecoder().decode(cookieValue.getBytes()); } catch (IllegalArgumentException ex) { throw new InvalidCookieException("Cookie token was not Base64 encoded; value was '" + cookieValue + "'"); } - String cookieAsPlainText = new String(Base64.getDecoder().decode(cookieValue.getBytes())); - String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER); - for (int i = 0; i < tokens.length; i++) { try { tokens[i] = URLDecoder.decode(tokens[i], StandardCharsets.UTF_8.toString()); @@ -236,7 +221,6 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, this.logger.error(ex.getMessage(), ex); } } - return tokens; } @@ -254,20 +238,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, catch (UnsupportedEncodingException ex) { this.logger.error(ex.getMessage(), ex); } - if (i < cookieTokens.length - 1) { sb.append(DELIMITER); } } - String value = sb.toString(); - sb = new StringBuilder(new String(Base64.getEncoder().encode(value.getBytes()))); - while (sb.charAt(sb.length() - 1) == '=') { sb.deleteCharAt(sb.length() - 1); } - return sb.toString(); } @@ -293,12 +272,10 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, @Override public final void loginSuccess(HttpServletRequest request, HttpServletResponse response, Authentication successfulAuthentication) { - if (!rememberMeRequested(request, this.parameter)) { this.logger.debug("Remember-me login not requested."); return; } - onLoginSuccess(request, response, successfulAuthentication); } @@ -324,20 +301,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, if (this.alwaysRemember) { return true; } - String paramValue = request.getParameter(parameter); - if (paramValue != null) { if (paramValue.equalsIgnoreCase("true") || paramValue.equalsIgnoreCase("on") || paramValue.equalsIgnoreCase("yes") || paramValue.equals("1")) { return true; } } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Did not send remember-me cookie (principal did not set parameter '" + parameter + "')"); - } - + this.logger.debug( + LogMessage.format("Did not send remember-me cookie (principal did not set parameter '%s')", parameter)); return false; } @@ -370,12 +342,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, if (this.cookieDomain != null) { cookie.setDomain(this.cookieDomain); } - if (this.useSecureCookie == null) { - cookie.setSecure(request.isSecure()); - } - else { - cookie.setSecure(this.useSecureCookie); - } + cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure()); response.addCookie(cookie); } @@ -402,16 +369,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, if (maxAge < 1) { cookie.setVersion(1); } - - if (this.useSecureCookie == null) { - cookie.setSecure(request.isSecure()); - } - else { - cookie.setSecure(this.useSecureCookie); - } - + cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure()); cookie.setHttpOnly(true); - response.addCookie(cookie); } @@ -426,9 +385,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, */ @Override public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown")); - } + this.logger.debug(LogMessage + .of(() -> "Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown"))); cancelCookie(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/InMemoryTokenRepositoryImpl.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/InMemoryTokenRepositoryImpl.java index aadf83fc7c..9eecb9778a 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/InMemoryTokenRepositoryImpl.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/InMemoryTokenRepositoryImpl.java @@ -36,21 +36,17 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository { @Override public synchronized void createNewToken(PersistentRememberMeToken token) { PersistentRememberMeToken current = this.seriesTokens.get(token.getSeries()); - if (current != null) { throw new DataIntegrityViolationException("Series Id '" + token.getSeries() + "' already exists!"); } - this.seriesTokens.put(token.getSeries(), token); } @Override public synchronized void updateToken(String series, String tokenValue, Date lastUsed) { PersistentRememberMeToken token = getTokenForSeries(series); - PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), series, tokenValue, new Date()); - // Store it, overwriting the existing one. this.seriesTokens.put(series, newToken); } @@ -63,12 +59,9 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository { @Override public synchronized void removeUserTokens(String username) { Iterator series = this.seriesTokens.keySet().iterator(); - while (series.hasNext()) { String seriesId = series.next(); - PersistentRememberMeToken token = this.seriesTokens.get(seriesId); - if (username.equals(token.getUsername())) { series.remove(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImpl.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImpl.java index bc14f0dd54..e33c5e6a3c 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImpl.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImpl.java @@ -16,8 +16,11 @@ package org.springframework.security.web.authentication.rememberme; +import java.sql.ResultSet; +import java.sql.SQLException; import java.util.Date; +import org.springframework.core.log.LogMessage; import org.springframework.dao.DataAccessException; import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException; @@ -87,27 +90,26 @@ public class JdbcTokenRepositoryImpl extends JdbcDaoSupport implements Persisten @Override public PersistentRememberMeToken getTokenForSeries(String seriesId) { try { - return getJdbcTemplate().queryForObject(this.tokensBySeriesSql, - (rs, rowNum) -> new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3), - rs.getTimestamp(4)), - seriesId); + return getJdbcTemplate().queryForObject(this.tokensBySeriesSql, this::createRememberMeToken, seriesId); } - catch (EmptyResultDataAccessException zeroResults) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Querying token for series '" + seriesId + "' returned no results.", zeroResults); - } + catch (EmptyResultDataAccessException ex) { + this.logger.debug(LogMessage.format("Querying token for series '%s' returned no results.", seriesId), ex); } - catch (IncorrectResultSizeDataAccessException moreThanOne) { - this.logger.error("Querying token for series '" + seriesId + "' returned more than one value. Series" - + " should be unique"); + catch (IncorrectResultSizeDataAccessException ex) { + this.logger.error(LogMessage.format( + "Querying token for series '%s' returned more than one value. Series" + " should be unique", + seriesId)); } catch (DataAccessException ex) { this.logger.error("Failed to load token for series " + seriesId, ex); } - return null; } + private PersistentRememberMeToken createRememberMeToken(ResultSet rs, int rowNum) throws SQLException { + return new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3), rs.getTimestamp(4)); + } + @Override public void removeUserTokens(String username) { getJdbcTemplate().update(this.removeUserTokensSql, username); diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/PersistentTokenBasedRememberMeServices.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/PersistentTokenBasedRememberMeServices.java index 8288d3e4a5..9914275ca2 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/PersistentTokenBasedRememberMeServices.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/PersistentTokenBasedRememberMeServices.java @@ -24,6 +24,7 @@ import java.util.Date; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -93,47 +94,35 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe @Override protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request, HttpServletResponse response) { - if (cookieTokens.length != 2) { throw new InvalidCookieException("Cookie token did not contain " + 2 + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'"); } - - final String presentedSeries = cookieTokens[0]; - final String presentedToken = cookieTokens[1]; - + String presentedSeries = cookieTokens[0]; + String presentedToken = cookieTokens[1]; PersistentRememberMeToken token = this.tokenRepository.getTokenForSeries(presentedSeries); - if (token == null) { // No series match, so we can't authenticate using this cookie throw new RememberMeAuthenticationException("No persistent token found for series id: " + presentedSeries); } - // We have a match for this user/series combination if (!presentedToken.equals(token.getTokenValue())) { // Token doesn't match series value. Delete all logins for this user and throw // an exception to warn them. this.tokenRepository.removeUserTokens(token.getUsername()); - throw new CookieTheftException(this.messages.getMessage( "PersistentTokenBasedRememberMeServices.cookieStolen", "Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack.")); } - if (token.getDate().getTime() + getTokenValiditySeconds() * 1000L < System.currentTimeMillis()) { throw new RememberMeAuthenticationException("Remember-me login has expired"); } - // Token also matches, so login is valid. Update the token value, keeping the // *same* series number. - if (this.logger.isDebugEnabled()) { - this.logger.debug("Refreshing persistent login token for user '" + token.getUsername() + "', series '" - + token.getSeries() + "'"); - } - + this.logger.debug(LogMessage.format("Refreshing persistent login token for user '%s', series '%s'", + token.getUsername(), token.getSeries())); PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), token.getSeries(), generateTokenData(), new Date()); - try { this.tokenRepository.updateToken(newToken.getSeries(), newToken.getTokenValue(), newToken.getDate()); addCookie(newToken, request, response); @@ -142,7 +131,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe this.logger.error("Failed to update token: ", ex); throw new RememberMeAuthenticationException("Autologin failed due to data access problem"); } - return getUserDetailsService().loadUserByUsername(token.getUsername()); } @@ -155,9 +143,7 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe protected void onLoginSuccess(HttpServletRequest request, HttpServletResponse response, Authentication successfulAuthentication) { String username = successfulAuthentication.getName(); - - this.logger.debug("Creating new persistent login for user " + username); - + this.logger.debug(LogMessage.format("Creating new persistent login for user %s", username)); PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, generateSeriesData(), generateTokenData(), new Date()); try { @@ -172,7 +158,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe @Override public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { super.logout(request, response, authentication); - if (authentication != null) { this.tokenRepository.removeUserTokens(authentication.getName()); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java index 8fb4a860d3..f41fd586a4 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java @@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent; import org.springframework.security.core.Authentication; @@ -86,66 +87,50 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; - - if (SecurityContextHolder.getContext().getAuthentication() == null) { - Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response); - - if (rememberMeAuth != null) { - // Attempt authenticaton via AuthenticationManager - try { - rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth); - - // Store to SecurityContextHolder - SecurityContextHolder.getContext().setAuthentication(rememberMeAuth); - - onSuccessfulAuthentication(request, response, rememberMeAuth); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("SecurityContextHolder populated with remember-me token: '" - + SecurityContextHolder.getContext().getAuthentication() + "'"); - } - - // Fire event - if (this.eventPublisher != null) { - this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( - SecurityContextHolder.getContext().getAuthentication(), this.getClass())); - } - - if (this.successHandler != null) { - this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth); - - return; - } + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { + if (SecurityContextHolder.getContext().getAuthentication() != null) { + this.logger.debug(LogMessage + .of(() -> "SecurityContextHolder not populated with remember-me token, as it already contained: '" + + SecurityContextHolder.getContext().getAuthentication() + "'")); + chain.doFilter(request, response); + return; + } + Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response); + if (rememberMeAuth != null) { + // Attempt authenticaton via AuthenticationManager + try { + rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth); + // Store to SecurityContextHolder + SecurityContextHolder.getContext().setAuthentication(rememberMeAuth); + onSuccessfulAuthentication(request, response, rememberMeAuth); + this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '" + + SecurityContextHolder.getContext().getAuthentication() + "'")); + if (this.eventPublisher != null) { + this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( + SecurityContextHolder.getContext().getAuthentication(), this.getClass())); } - catch (AuthenticationException authenticationException) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("SecurityContextHolder not populated with remember-me token, as " - + "AuthenticationManager rejected Authentication returned by RememberMeServices: '" - + rememberMeAuth + "'; invalidating remember-me token", authenticationException); - } - - this.rememberMeServices.loginFail(request, response); - - onUnsuccessfulAuthentication(request, response, authenticationException); + if (this.successHandler != null) { + this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth); + return; } } - - chain.doFilter(request, response); - } - else { - if (this.logger.isDebugEnabled()) { - this.logger - .debug("SecurityContextHolder not populated with remember-me token, as it already contained: '" - + SecurityContextHolder.getContext().getAuthentication() + "'"); + catch (AuthenticationException ex) { + this.logger.debug(LogMessage + .format("SecurityContextHolder not populated with remember-me token, as AuthenticationManager " + + "rejected Authentication returned by RememberMeServices: '%s'; " + + "invalidating remember-me token", rememberMeAuth), + ex); + this.rememberMeServices.loginFail(request, response); + onUnsuccessfulAuthentication(request, response, ex); } - - chain.doFilter(request, response); } + chain.doFilter(request, response); } /** diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/TokenBasedRememberMeServices.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/TokenBasedRememberMeServices.java index d75d228ed2..2facda2bc1 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/TokenBasedRememberMeServices.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/TokenBasedRememberMeServices.java @@ -90,52 +90,43 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { @Override protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request, HttpServletResponse response) { - if (cookieTokens.length != 3) { throw new InvalidCookieException( "Cookie token did not contain 3" + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'"); } + long tokenExpiryTime = getTokenExpiryTime(cookieTokens); + if (isTokenExpired(tokenExpiryTime)) { + throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime) + + "'; current time is '" + new Date() + "')"); + } + // Check the user exists. Defer lookup until after expiry time checked, to + // possibly avoid expensive database call. + UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]); + Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService() + + " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation"); + // Check signature of token matches remaining details. Must do this after user + // lookup, as we need the DAO-derived password. If efficiency was a major issue, + // just add in a UserCache implementation, but recall that this method is usually + // only called once per HttpSession - if the token is valid, it will cause + // SecurityContextHolder population, whilst if invalid, will cause the cookie to + // be cancelled. + String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(), + userDetails.getPassword()); + if (!equals(expectedTokenSignature, cookieTokens[2])) { + throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2] + + "' but expected '" + expectedTokenSignature + "'"); + } + return userDetails; + } - long tokenExpiryTime; - + private long getTokenExpiryTime(String[] cookieTokens) { try { - tokenExpiryTime = new Long(cookieTokens[1]); + return new Long(cookieTokens[1]); } catch (NumberFormatException nfe) { throw new InvalidCookieException( "Cookie token[1] did not contain a valid number (contained '" + cookieTokens[1] + "')"); } - - if (isTokenExpired(tokenExpiryTime)) { - throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime) - + "'; current time is '" + new Date() + "')"); - } - - // Check the user exists. - // Defer lookup until after expiry time checked, to possibly avoid expensive - // database call. - - UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]); - - Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService() - + " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation"); - - // Check signature of token matches remaining details. - // Must do this after user lookup, as we need the DAO-derived password. - // If efficiency was a major issue, just add in a UserCache implementation, - // but recall that this method is usually only called once per HttpSession - if - // the token is valid, - // it will cause SecurityContextHolder population, whilst if invalid, will cause - // the cookie to be cancelled. - String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(), - userDetails.getPassword()); - - if (!equals(expectedTokenSignature, cookieTokens[2])) { - throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2] - + "' but expected '" + expectedTokenSignature + "'"); - } - - return userDetails; } /** @@ -144,15 +135,13 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { */ protected String makeTokenSignature(long tokenExpiryTime, String username, String password) { String data = username + ":" + tokenExpiryTime + ":" + password + ":" + getKey(); - MessageDigest digest; try { - digest = MessageDigest.getInstance("MD5"); + MessageDigest digest = MessageDigest.getInstance("MD5"); + return new String(Hex.encode(digest.digest(data.getBytes()))); } catch (NoSuchAlgorithmException ex) { throw new IllegalStateException("No MD5 algorithm available!"); } - - return new String(Hex.encode(digest.digest(data.getBytes()))); } protected boolean isTokenExpired(long tokenExpiryTime) { @@ -162,10 +151,8 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { @Override public void onLoginSuccess(HttpServletRequest request, HttpServletResponse response, Authentication successfulAuthentication) { - String username = retrieveUserName(successfulAuthentication); String password = retrievePassword(successfulAuthentication); - // If unable to find a username and password, just abort as // TokenBasedRememberMeServices is // unable to construct a valid token in this case. @@ -173,27 +160,21 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { this.logger.debug("Unable to retrieve username"); return; } - if (!StringUtils.hasLength(password)) { UserDetails user = getUserDetailsService().loadUserByUsername(username); password = user.getPassword(); - if (!StringUtils.hasLength(password)) { this.logger.debug("Unable to obtain password for user: " + username); return; } } - int tokenLifetime = calculateLoginLifetime(request, successfulAuthentication); long expiryTime = System.currentTimeMillis(); // SEC-949 expiryTime += 1000L * ((tokenLifetime < 0) ? TWO_WEEKS_S : tokenLifetime); - String signatureValue = makeTokenSignature(expiryTime, username, password); - setCookie(new String[] { username, Long.toString(expiryTime), signatureValue }, tokenLifetime, request, response); - if (this.logger.isDebugEnabled()) { this.logger.debug( "Added remember-me cookie for user '" + username + "', expiry: '" + new Date(expiryTime) + "'"); @@ -223,21 +204,17 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { if (isInstanceOfUserDetails(authentication)) { return ((UserDetails) authentication.getPrincipal()).getUsername(); } - else { - return authentication.getPrincipal().toString(); - } + return authentication.getPrincipal().toString(); } protected String retrievePassword(Authentication authentication) { if (isInstanceOfUserDetails(authentication)) { return ((UserDetails) authentication.getPrincipal()).getPassword(); } - else { - if (authentication.getCredentials() == null) { - return null; - } + if (authentication.getCredentials() != null) { return authentication.getCredentials().toString(); } + return null; } private boolean isInstanceOfUserDetails(Authentication authentication) { @@ -250,15 +227,11 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices { private static boolean equals(String expected, String actual) { byte[] expectedBytes = bytesUtf8(expected); byte[] actualBytes = bytesUtf8(actual); - return MessageDigest.isEqual(expectedBytes, actualBytes); } private static byte[] bytesUtf8(String s) { - if (s == null) { - return null; - } - return Utf8.encode(s); + return (s != null) ? Utf8.encode(s) : null; } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/session/AbstractSessionFixationProtectionStrategy.java b/web/src/main/java/org/springframework/security/web/authentication/session/AbstractSessionFixationProtectionStrategy.java index 5ff51a6131..c4f497bf1b 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/session/AbstractSessionFixationProtectionStrategy.java +++ b/web/src/main/java/org/springframework/security/web/authentication/session/AbstractSessionFixationProtectionStrategy.java @@ -73,35 +73,26 @@ public abstract class AbstractSessionFixationProtectionStrategy public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) { boolean hadSessionAlready = request.getSession(false) != null; - if (!hadSessionAlready && !this.alwaysCreateSession) { // Session fixation isn't a problem if there's no session - return; } - // Create new session if necessary HttpSession session = request.getSession(); - if (hadSessionAlready && request.isRequestedSessionIdValid()) { - String originalSessionId; String newSessionId; Object mutex = WebUtils.getSessionMutex(session); synchronized (mutex) { // We need to migrate to a new session originalSessionId = session.getId(); - session = applySessionFixation(request); newSessionId = session.getId(); } - if (originalSessionId.equals(newSessionId)) { - this.logger.warn( - "Your servlet container did not change the session ID when a new session was created. You will" - + " not be adequately protected against session-fixation attacks"); + this.logger.warn("Your servlet container did not change the session ID when a new session " + + "was created. You will not be adequately protected against session-fixation attacks"); } - onSessionChange(originalSessionId, session, authentication); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/session/CompositeSessionAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/authentication/session/CompositeSessionAuthenticationStrategy.java index d58303fa46..b1e53ec772 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/session/CompositeSessionAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/authentication/session/CompositeSessionAuthenticationStrategy.java @@ -25,6 +25,7 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.util.Assert; @@ -63,10 +64,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica public CompositeSessionAuthenticationStrategy(List delegateStrategies) { Assert.notEmpty(delegateStrategies, "delegateStrategies cannot be null or empty"); for (SessionAuthenticationStrategy strategy : delegateStrategies) { - if (strategy == null) { - throw new IllegalArgumentException( - "delegateStrategies cannot contain null entires. Got " + delegateStrategies); - } + Assert.notNull(strategy, () -> "delegateStrategies cannot contain null entires. Got " + delegateStrategies); } this.delegateStrategies = delegateStrategies; } @@ -75,9 +73,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { for (SessionAuthenticationStrategy delegate : this.delegateStrategies) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Delegating to " + delegate); - } + this.logger.debug(LogMessage.format("Delegating to %s", delegate)); delegate.onAuthentication(authentication, request, response); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/session/ConcurrentSessionControlAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/authentication/session/ConcurrentSessionControlAuthenticationStrategy.java index e91ad70f07..7e96cf3d75 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/session/ConcurrentSessionControlAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/authentication/session/ConcurrentSessionControlAuthenticationStrategy.java @@ -94,26 +94,19 @@ public class ConcurrentSessionControlAuthenticationStrategy @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) { - - final List sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(), - false); - + List sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(), false); int sessionCount = sessions.size(); int allowedSessions = getMaximumSessionsForThisUser(authentication); - if (sessionCount < allowedSessions) { // They haven't got too many login sessions running at present return; } - if (allowedSessions == -1) { // We permit unlimited logins return; } - if (sessionCount == allowedSessions) { HttpSession session = request.getSession(false); - if (session != null) { // Only permit it though if this request is associated with one of the // already registered sessions @@ -126,7 +119,6 @@ public class ConcurrentSessionControlAuthenticationStrategy // If the session is null, a new one will be created by the parent class, // exceeding the allowed number } - allowableSessionsExceeded(sessions, allowedSessions, this.sessionRegistry); } @@ -157,7 +149,6 @@ public class ConcurrentSessionControlAuthenticationStrategy this.messages.getMessage("ConcurrentSessionControlAuthenticationStrategy.exceededAllowed", new Object[] { allowableSessions }, "Maximum sessions of {0} for this principal exceeded")); } - // Determine least recently used sessions, and mark them for invalidation sessions.sort(Comparator.comparing(SessionInformation::getLastRequest)); int maximumSessionsExceededBy = sessions.size() - allowableSessions + 1; diff --git a/web/src/main/java/org/springframework/security/web/authentication/session/SessionFixationProtectionStrategy.java b/web/src/main/java/org/springframework/security/web/authentication/session/SessionFixationProtectionStrategy.java index cd761c28fe..842ba85b77 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/session/SessionFixationProtectionStrategy.java +++ b/web/src/main/java/org/springframework/security/web/authentication/session/SessionFixationProtectionStrategy.java @@ -23,6 +23,8 @@ import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import org.springframework.core.log.LogMessage; + /** * Uses {@code HttpServletRequest.invalidate()} to protect against session fixation * attacks. @@ -82,21 +84,13 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr final HttpSession applySessionFixation(HttpServletRequest request) { HttpSession session = request.getSession(); String originalSessionId = session.getId(); - if (this.logger.isDebugEnabled()) { - this.logger.debug("Invalidating session with Id '" + originalSessionId + "' " - + (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes."); - } - + this.logger.debug(LogMessage.of(() -> "Invalidating session with Id '" + originalSessionId + "' " + + (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes.")); Map attributesToMigrate = extractAttributes(session); int maxInactiveIntervalToMigrate = session.getMaxInactiveInterval(); - session.invalidate(); session = request.getSession(true); // we now have a new session - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Started new session: " + session.getId()); - } - + this.logger.debug(LogMessage.format("Started new session: %s", session.getId())); transferAttributes(attributesToMigrate, session); if (this.migrateSessionAttributes) { session.setMaxInactiveInterval(maxInactiveIntervalToMigrate); @@ -111,27 +105,22 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr */ void transferAttributes(Map attributes, HttpSession newSession) { if (attributes != null) { - for (Map.Entry entry : attributes.entrySet()) { - newSession.setAttribute(entry.getKey(), entry.getValue()); - } + attributes.forEach(newSession::setAttribute); } } @SuppressWarnings("unchecked") private HashMap createMigratedAttributeMap(HttpSession session) { HashMap attributesToMigrate = new HashMap<>(); - - Enumeration enumer = session.getAttributeNames(); - - while (enumer.hasMoreElements()) { - String key = (String) enumer.nextElement(); + Enumeration enumeration = session.getAttributeNames(); + while (enumeration.hasMoreElements()) { + String key = enumeration.nextElement(); if (!this.migrateSessionAttributes && !key.startsWith("SPRING_SECURITY_")) { // Only retain Spring Security attributes continue; } attributesToMigrate.put(key, session.getAttribute(key)); } - return attributesToMigrate; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java index 58e1e4fb44..6fff566ed7 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java @@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AccountExpiredException; import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; @@ -149,7 +150,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv Assert.isNull(this.successHandler, "You cannot set both successHandler and targetUrl"); this.successHandler = new SimpleUrlAuthenticationSuccessHandler(this.targetUrl); } - if (this.failureHandler == null) { this.failureHandler = (this.switchFailureUrl != null) ? new SimpleUrlAuthenticationFailureHandler(this.switchFailureUrl) @@ -161,20 +161,20 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { // check for switch or exit request if (requiresSwitchUser(request)) { // if set, attempt switch and store original try { Authentication targetUser = attemptSwitchUser(request); - // update the current context to the new target user SecurityContextHolder.getContext().setAuthentication(targetUser); - // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, targetUser); } @@ -182,22 +182,17 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv this.logger.debug("Switch User failed", ex); this.failureHandler.onAuthenticationFailure(request, response, ex); } - return; } - else if (requiresExitUser(request)) { + if (requiresExitUser(request)) { // get the original authentication object (if exists) Authentication originalUser = attemptExitUser(request); - // update the current context back to the original user SecurityContextHolder.getContext().setAuthentication(originalUser); - // redirect to target url this.successHandler.onAuthenticationSuccess(request, response, originalUser); - return; } - chain.doFilter(request, response); } @@ -214,33 +209,19 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv */ protected Authentication attemptSwitchUser(HttpServletRequest request) throws AuthenticationException { UsernamePasswordAuthenticationToken targetUserRequest; - String username = request.getParameter(this.usernameParameter); - - if (username == null) { - username = ""; - } - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Attempt to switch to user [" + username + "]"); - } - + username = (username != null) ? username : ""; + this.logger.debug(LogMessage.format("Attempt to switch to user [%s]", username)); UserDetails targetUser = this.userDetailsService.loadUserByUsername(username); this.userDetailsChecker.check(targetUser); - // OK, create the switch user token targetUserRequest = createSwitchUserToken(request, targetUser); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Switch User Token [" + targetUserRequest + "]"); - } - + this.logger.debug(LogMessage.format("Switch User Token [%s]", targetUserRequest)); // publish event if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent( SecurityContextHolder.getContext().getAuthentication(), targetUser)); } - return targetUserRequest; } @@ -256,35 +237,28 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv throws AuthenticationCredentialsNotFoundException { // need to check to see if the current user has a SwitchUserGrantedAuthority Authentication current = SecurityContextHolder.getContext().getAuthentication(); - - if (null == current) { + if (current == null) { throw new AuthenticationCredentialsNotFoundException(this.messages .getMessage("SwitchUserFilter.noCurrentUser", "No current user associated with this request")); } - // check to see if the current user did actual switch to another user // if so, get the original source user so we can switch back Authentication original = getSourceAuthentication(current); - if (original == null) { this.logger.debug("Could not find original user Authentication object!"); throw new AuthenticationCredentialsNotFoundException(this.messages.getMessage( "SwitchUserFilter.noOriginalAuthentication", "Could not find original Authentication object")); } - // get the source user details UserDetails originalUser = null; Object obj = original.getPrincipal(); - if ((obj != null) && obj instanceof UserDetails) { originalUser = (UserDetails) obj; } - // publish event if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(current, originalUser)); } - return original; } @@ -299,45 +273,38 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv */ private UsernamePasswordAuthenticationToken createSwitchUserToken(HttpServletRequest request, UserDetails targetUser) { - UsernamePasswordAuthenticationToken targetUserRequest; - // grant an additional authority that contains the original Authentication object // which will be used to 'exit' from the current switched user. - - Authentication currentAuth; - - try { - // SEC-1763. Check first if we are already switched. - currentAuth = attemptExitUser(request); - } - catch (AuthenticationCredentialsNotFoundException ex) { - currentAuth = SecurityContextHolder.getContext().getAuthentication(); - } - - GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole, currentAuth); - + Authentication currentAuthentication = getCurrentAuthentication(request); + GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole, + currentAuthentication); // get the original authorities Collection orig = targetUser.getAuthorities(); - // Allow subclasses to change the authorities to be granted if (this.switchUserAuthorityChanger != null) { - orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuth, orig); + orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuthentication, orig); } - // add the new switch user authority List newAuths = new ArrayList<>(orig); newAuths.add(switchAuthority); - // create the new authentication token targetUserRequest = new UsernamePasswordAuthenticationToken(targetUser, targetUser.getPassword(), newAuths); - // set details targetUserRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - return targetUserRequest; } + private Authentication getCurrentAuthentication(HttpServletRequest request) { + try { + // SEC-1763. Check first if we are already switched. + return attemptExitUser(request); + } + catch (AuthenticationCredentialsNotFoundException ex) { + return SecurityContextHolder.getContext().getAuthentication(); + } + } + /** * Find the original Authentication object from the current user's * granted authorities. A successfully switched user should have a @@ -349,10 +316,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv */ private Authentication getSourceAuthentication(Authentication current) { Authentication original = null; - // iterate over granted authorities and find the 'switch user' authority Collection authorities = current.getAuthorities(); - for (GrantedAuthority auth : authorities) { // check for switch user type of authority if (auth instanceof SwitchUserGrantedAuthority) { @@ -360,7 +325,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv this.logger.debug("Found original switch user granted authority [" + original + "]"); } } - return original; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java index f3c96da440..06f0955c65 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java @@ -112,24 +112,28 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { this.logoutSuccessUrl = DEFAULT_LOGIN_PAGE_URL + "?logout"; this.failureUrl = DEFAULT_LOGIN_PAGE_URL + "?" + ERROR_PARAMETER_NAME; if (authFilter != null) { - this.formLoginEnabled = true; - this.usernameParameter = authFilter.getUsernameParameter(); - this.passwordParameter = authFilter.getPasswordParameter(); - - if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) { - this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices()) - .getParameter(); - } + initAuthFilter(authFilter); } - if (openIDFilter != null) { - this.openIdEnabled = true; - this.openIDusernameParameter = "openid_identifier"; + initOpenIdFilter(openIDFilter); + } + } - if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) { - this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices()) - .getParameter(); - } + private void initAuthFilter(UsernamePasswordAuthenticationFilter authFilter) { + this.formLoginEnabled = true; + this.usernameParameter = authFilter.getUsernameParameter(); + this.passwordParameter = authFilter.getPasswordParameter(); + if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) { + this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices()).getParameter(); + } + } + + private void initOpenIdFilter(AbstractAuthenticationProcessingFilter openIDFilter) { + this.openIdEnabled = true; + this.openIDusernameParameter = "openid_identifier"; + if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) { + this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices()) + .getParameter(); } } @@ -214,11 +218,13 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { boolean loginError = isErrorPage(request); boolean logoutSuccess = isLogoutSuccess(request); if (isLoginUrlRequest(request) || loginError || logoutSuccess) { @@ -226,66 +232,69 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { response.setContentType("text/html;charset=UTF-8"); response.setContentLength(loginPageHtml.getBytes(StandardCharsets.UTF_8).length); response.getWriter().write(loginPageHtml); - return; } - chain.doFilter(request, response); } private String generateLoginPageHtml(HttpServletRequest request, boolean loginError, boolean logoutSuccess) { String errorMsg = "Invalid credentials"; - if (loginError) { HttpSession session = request.getSession(false); - if (session != null) { AuthenticationException ex = (AuthenticationException) session .getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION); errorMsg = (ex != null) ? ex.getMessage() : "Invalid credentials"; } } - - StringBuilder sb = new StringBuilder(); - - sb.append("\n" + "\n" + " \n" + " \n" - + " \n" - + " \n" + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" + " \n" + "

\n"); - String contextPath = request.getContextPath(); + StringBuilder sb = new StringBuilder(); + sb.append("\n"); + sb.append("\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" Please sign in\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append("
\n"); if (this.formLoginEnabled) { sb.append("
\n" - + " \n" - + createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n" - + " \n" - + " \n" + "

\n" - + "

\n" + " \n" - + " \n" + "

\n" - + createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request) - + " \n" - + "
\n"); + + this.authenticationUrl + "\">\n"); + sb.append(" \n"); + sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append("

\n"); + sb.append("

\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append("

\n"); + sb.append(createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request)); + sb.append(" \n"); + sb.append(" \n"); } - if (this.openIdEnabled) { sb.append("
\n" - + " \n" - + createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n" - + " \n" - + " \n" + "

\n" - + createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request) - + " \n" - + "
\n"); + + this.openIDauthenticationUrl + "\">\n"); + sb.append(" \n"); + sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append("

\n"); + sb.append(createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request)); + sb.append(" \n"); + sb.append(" \n"); } - if (this.oauth2LoginEnabled) { sb.append(""); sb.append(createError(loginError, errorMsg)); @@ -303,7 +312,6 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } sb.append("\n"); } - if (this.saml2LoginEnabled) { sb.append(""); sb.append(createError(loginError, errorMsg)); @@ -323,15 +331,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } sb.append("
\n"); sb.append(""); - return sb.toString(); } private String renderHiddenInputs(HttpServletRequest request) { StringBuilder sb = new StringBuilder(); for (Map.Entry input : this.resolveHiddenInputs.apply(request).entrySet()) { - sb.append("\n"); + sb.append("\n"); } return sb.toString(); } @@ -356,13 +366,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } private static String createError(boolean isError, String message) { - return isError ? "
" + HtmlUtils.htmlEscape(message) + "
" - : ""; + if (!isError) { + return ""; + } + return "
" + HtmlUtils.htmlEscape(message) + "
"; } private static String createLogoutSuccess(boolean isLogoutSuccess) { - return isLogoutSuccess ? "
You have been signed out
" - : ""; + if (!isLogoutSuccess) { + return ""; + } + return "
You have been signed out
"; } private boolean matches(HttpServletRequest request, String url) { @@ -371,20 +385,16 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { } String uri = request.getRequestURI(); int pathParamIndex = uri.indexOf(';'); - if (pathParamIndex > 0) { // strip everything after the first semi-colon uri = uri.substring(0, pathParamIndex); } - if (request.getQueryString() != null) { uri += "?" + request.getQueryString(); } - if ("".equals(request.getContextPath())) { return uri.equals(url); } - return uri.equals(request.getContextPath() + url); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLogoutPageGeneratingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLogoutPageGeneratingFilter.java index 7dae48a336..84f6cde296 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLogoutPageGeneratingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLogoutPageGeneratingFilter.java @@ -55,21 +55,34 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter { } private void renderLogout(HttpServletRequest request, HttpServletResponse response) throws IOException { - String page = "\n" + "\n" + " \n" + " \n" - + " \n" - + " \n" + " \n" - + " Confirm Log Out?\n" - + " \n" - + " \n" - + " \n" + " \n" + "
\n" - + "
\n" + " \n" - + renderHiddenInputs(request) - + " \n" - + "
\n" + "
\n" + " \n" + ""; - + StringBuilder sb = new StringBuilder(); + sb.append("\n"); + sb.append("\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" Confirm Log Out?\n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append(" \n"); + sb.append("
\n"); + sb.append("
\n"); + sb.append(" \n"); + sb.append(renderHiddenInputs(request) + + " \n"); + sb.append("
\n"); + sb.append("
\n"); + sb.append(" \n"); + sb.append(""); response.setContentType("text/html;charset=UTF-8"); - response.getWriter().write(page); + response.getWriter().write(sb.toString()); } /** @@ -86,8 +99,11 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter { private String renderHiddenInputs(HttpServletRequest request) { StringBuilder sb = new StringBuilder(); for (Map.Entry input : this.resolveHiddenInputs.apply(request).entrySet()) { - sb.append("\n"); + sb.append("\n"); } return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java index fdbbb743bf..2e39a67624 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java @@ -80,29 +80,17 @@ public class BasicAuthenticationConverter implements AuthenticationConverter { if (header == null) { return null; } - header = header.trim(); if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) { return null; } - if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) { throw new BadCredentialsException("Empty basic authentication token"); } - byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8); - byte[] decoded; - try { - decoded = Base64.getDecoder().decode(base64Token); - } - catch (IllegalArgumentException ex) { - throw new BadCredentialsException("Failed to decode basic authentication token"); - } - + byte[] decoded = decode(base64Token); String token = new String(decoded, getCredentialsCharset(request)); - int delim = token.indexOf(":"); - if (delim == -1) { throw new BadCredentialsException("Invalid basic authentication token"); } @@ -112,6 +100,15 @@ public class BasicAuthenticationConverter implements AuthenticationConverter { return result; } + private byte[] decode(byte[] base64Token) { + try { + return Base64.getDecoder().decode(base64Token); + } + catch (IllegalArgumentException ex) { + throw new BadCredentialsException("Failed to decode basic authentication token"); + } + } + protected Charset getCredentialsCharset(HttpServletRequest request) { return getCredentialsCharset(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java index 620179bebe..199eb1d429 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java @@ -24,6 +24,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; @@ -132,7 +133,6 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { @Override public void afterPropertiesSet() { Assert.notNull(this.authenticationManager, "An AuthenticationManager is required"); - if (!isIgnoreFailure()) { Assert.notNull(this.authenticationEntryPoint, "An AuthenticationEntryPoint is required"); } @@ -141,53 +141,34 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { - final boolean debug = this.logger.isDebugEnabled(); try { UsernamePasswordAuthenticationToken authRequest = this.authenticationConverter.convert(request); if (authRequest == null) { chain.doFilter(request, response); return; } - String username = authRequest.getName(); - - if (debug) { - this.logger.debug("Basic Authentication Authorization header found for user '" + username + "'"); - } - + this.logger.debug( + LogMessage.format("Basic Authentication Authorization header found for user '%s'", username)); if (authenticationIsRequired(username)) { Authentication authResult = this.authenticationManager.authenticate(authRequest); - - if (debug) { - this.logger.debug("Authentication success: " + authResult); - } - + this.logger.debug(LogMessage.format("Authentication success: %s", authResult)); SecurityContextHolder.getContext().setAuthentication(authResult); - this.rememberMeServices.loginSuccess(request, response, authResult); - onSuccessfulAuthentication(request, response, authResult); } - } - catch (AuthenticationException failed) { + catch (AuthenticationException ex) { SecurityContextHolder.clearContext(); - - if (debug) { - this.logger.debug("Authentication request for failed!", failed); - } - + this.logger.debug("Authentication request for failed!", ex); this.rememberMeServices.loginFail(request, response); - - onUnsuccessfulAuthentication(request, response, failed); - + onUnsuccessfulAuthentication(request, response, ex); if (this.ignoreFailure) { chain.doFilter(request, response); } else { - this.authenticationEntryPoint.commence(request, response, failed); + this.authenticationEntryPoint.commence(request, response, ex); } - return; } @@ -196,40 +177,26 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { private boolean authenticationIsRequired(String username) { // Only reauthenticate if username doesn't match SecurityContextHolder and user - // isn't authenticated - // (see SEC-53) + // isn't authenticated (see SEC-53) Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication(); - if (existingAuth == null || !existingAuth.isAuthenticated()) { return true; } - // Limit username comparison to providers which use usernames (ie - // UsernamePasswordAuthenticationToken) - // (see SEC-348) - + // UsernamePasswordAuthenticationToken) (see SEC-348) if (existingAuth instanceof UsernamePasswordAuthenticationToken && !existingAuth.getName().equals(username)) { return true; } - // Handle unusual condition where an AnonymousAuthenticationToken is already - // present - // This shouldn't happen very often, as BasicProcessingFitler is meant to be - // earlier in the filter - // chain than AnonymousAuthenticationFilter. Nevertheless, presence of both an - // AnonymousAuthenticationToken - // together with a BASIC authentication request header should indicate - // reauthentication using the + // present. This shouldn't happen very often, as BasicProcessingFitler is meant to + // be earlier in the filter chain than AnonymousAuthenticationFilter. + // Nevertheless, presence of both an AnonymousAuthenticationToken together with a + // BASIC authentication request header should indicate reauthentication using the // BASIC protocol is desirable. This behaviour is also consistent with that - // provided by form and digest, - // both of which force re-authentication if the respective header is detected (and - // in doing so replace - // any existing AnonymousAuthenticationToken). See SEC-610. - if (existingAuth instanceof AnonymousAuthenticationToken) { - return true; - } - - return false; + // provided by form and digest, both of which force re-authentication if the + // respective header is detected (and in doing so replace/ any existing + // AnonymousAuthenticationToken). See SEC-610. + return (existingAuth instanceof AnonymousAuthenticationToken); } protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthUtils.java b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthUtils.java index 506dd648dd..a3621e4225 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthUtils.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthUtils.java @@ -44,18 +44,14 @@ final class DigestAuthUtils { if (str == null) { return null; } - int len = str.length(); - if (len == 0) { return EMPTY_STRING_ARRAY; } - List list = new ArrayList<>(); int i = 0; int start = 0; boolean match = false; - while (i < len) { if (str.charAt(i) == '"') { i++; @@ -83,7 +79,6 @@ final class DigestAuthUtils { if (match) { list.add(str.substring(start, i)); } - return list.toArray(new String[0]); } @@ -108,32 +103,19 @@ final class DigestAuthUtils { static String generateDigest(boolean passwordAlreadyEncoded, String username, String realm, String password, String httpMethod, String uri, String qop, String nonce, String nc, String cnonce) throws IllegalArgumentException { - String a1Md5; String a2 = httpMethod + ":" + uri; + String a1Md5 = (!passwordAlreadyEncoded) ? DigestAuthUtils.encodePasswordInA1Format(username, realm, password) + : password; String a2Md5 = md5Hex(a2); - - if (passwordAlreadyEncoded) { - a1Md5 = password; - } - else { - a1Md5 = DigestAuthUtils.encodePasswordInA1Format(username, realm, password); - } - - String digest; - if (qop == null) { // as per RFC 2069 compliant clients (also reaffirmed by RFC 2617) - digest = a1Md5 + ":" + nonce + ":" + a2Md5; + return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5); } - else if ("auth".equals(qop)) { + if ("auth".equals(qop)) { // As per RFC 2617 compliant clients - digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5; + return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5); } - else { - throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'"); - } - - return md5Hex(digest); + throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'"); } /** @@ -157,28 +139,15 @@ final class DigestAuthUtils { if ((array == null) || (array.length == 0)) { return null; } - Map map = new HashMap<>(); - for (String s : array) { - String postRemove; - - if (removeCharacters == null) { - postRemove = s; - } - else { - postRemove = StringUtils.replace(s, removeCharacters, ""); - } - + String postRemove = (removeCharacters != null) ? StringUtils.replace(s, removeCharacters, "") : s; String[] splitThisArrayElement = split(postRemove, delimiter); - if (splitThisArrayElement == null) { continue; } - map.put(splitThisArrayElement[0].trim(), splitThisArrayElement[1].trim()); } - return map; } @@ -196,33 +165,24 @@ final class DigestAuthUtils { static String[] split(String toSplit, String delimiter) { Assert.hasLength(toSplit, "Cannot split a null or empty string"); Assert.hasLength(delimiter, "Cannot use a null or empty delimiter to split a string"); - - if (delimiter.length() != 1) { - throw new IllegalArgumentException("Delimiter can only be one character in length"); - } - + Assert.isTrue(delimiter.length() == 1, "Delimiter can only be one character in length"); int offset = toSplit.indexOf(delimiter); - if (offset < 0) { return null; } - String beforeDelimiter = toSplit.substring(0, offset); String afterDelimiter = toSplit.substring(offset + 1); - return new String[] { beforeDelimiter, afterDelimiter }; } static String md5Hex(String data) { - MessageDigest digest; try { - digest = MessageDigest.getInstance("MD5"); + MessageDigest digest = MessageDigest.getInstance("MD5"); + return new String(Hex.encode(digest.digest(data.getBytes()))); } catch (NoSuchAlgorithmException ex) { throw new IllegalStateException("No MD5 algorithm available!"); } - - return new String(Hex.encode(digest.digest(data.getBytes()))); } } diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationEntryPoint.java index 6fe4b6b4e5..ac547cabb5 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationEntryPoint.java @@ -27,9 +27,11 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.core.Ordered; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.util.Assert; /** * Used by the SecurityEnforcementFilter to commence authentication via the @@ -68,44 +70,30 @@ public class DigestAuthenticationEntryPoint implements AuthenticationEntryPoint, @Override public void afterPropertiesSet() { - if ((this.realmName == null) || "".equals(this.realmName)) { - throw new IllegalArgumentException("realmName must be specified"); - } - - if ((this.key == null) || "".equals(this.key)) { - throw new IllegalArgumentException("key must be specified"); - } + Assert.hasLength(this.realmName, "realmName must be specified"); + Assert.hasLength(this.key, "key must be specified"); } @Override public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException { - HttpServletResponse httpResponse = response; - - // compute a nonce (do not use remote IP address due to proxy farms) - // format of nonce is: - // base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key)) + // compute a nonce (do not use remote IP address due to proxy farms) format of + // nonce is: base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key)) long expiryTime = System.currentTimeMillis() + (this.nonceValiditySeconds * 1000); String signatureValue = DigestAuthUtils.md5Hex(expiryTime + ":" + this.key); String nonceValue = expiryTime + ":" + signatureValue; String nonceValueBase64 = new String(Base64.getEncoder().encode(nonceValue.getBytes())); - - // qop is quality of protection, as defined by RFC 2617. - // we do not use opaque due to IE violation of RFC 2617 in not - // representing opaque on subsequent requests in same session. + // qop is quality of protection, as defined by RFC 2617. We do not use opaque due + // to IE violation of RFC 2617 in not representing opaque on subsequent requests + // in same session. String authenticateHeader = "Digest realm=\"" + this.realmName + "\", " + "qop=\"auth\", nonce=\"" + nonceValueBase64 + "\""; - if (authException instanceof NonceExpiredException) { authenticateHeader = authenticateHeader + ", stale=\"true\""; } - - if (logger.isDebugEnabled()) { - logger.debug("WWW-Authenticate header sent to user agent: " + authenticateHeader); - } - - httpResponse.addHeader("WWW-Authenticate", authenticateHeader); - httpResponse.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase()); + logger.debug(LogMessage.format("WWW-Authenticate header sent to user agent: %s", authenticateHeader)); + response.addHeader("WWW-Authenticate", authenticateHeader); + response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase()); } public String getKey() { diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java index abda070507..12cc3af722 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java @@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.BadCredentialsException; @@ -112,136 +113,105 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { String header = request.getHeader("Authorization"); - if (header == null || !header.startsWith("Digest ")) { chain.doFilter(request, response); - return; } - - if (logger.isDebugEnabled()) { - logger.debug("Digest Authorization header received from user agent: " + header); - } - + logger.debug(LogMessage.format("Digest Authorization header received from user agent: %s", header)); DigestData digestAuth = new DigestData(header); - try { digestAuth.validateAndDecode(this.authenticationEntryPoint.getKey(), this.authenticationEntryPoint.getRealmName()); } catch (BadCredentialsException ex) { fail(request, response, ex); - return; } - - // Lookup password for presented username - // NB: DAO-provided password MUST be clear text - not encoded/salted - // (unless this instance's passwordAlreadyEncoded property is 'false') + // Lookup password for presented username. N.B. DAO-provided password MUST be + // clear text - not encoded/salted (unless this instance's passwordAlreadyEncoded + // property is 'false') boolean cacheWasUsed = true; UserDetails user = this.userCache.getUserFromCache(digestAuth.getUsername()); String serverDigestMd5; - try { if (user == null) { cacheWasUsed = false; user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername()); - if (user == null) { throw new AuthenticationServiceException( "AuthenticationDao returned null, which is an interface contract violation"); } - this.userCache.putUserInCache(user); } - serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod()); - // If digest is incorrect, try refreshing from backend and recomputing if (!serverDigestMd5.equals(digestAuth.getResponse()) && cacheWasUsed) { - if (logger.isDebugEnabled()) { - logger.debug( - "Digest comparison failure; trying to refresh user from DAO in case password had changed"); - } - + logger.debug("Digest comparison failure; trying to refresh user from DAO in case password had changed"); user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername()); this.userCache.putUserInCache(user); serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod()); } - } - catch (UsernameNotFoundException notFound) { - fail(request, response, - new BadCredentialsException(this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound", - new Object[] { digestAuth.getUsername() }, "Username {0} not found"))); - + catch (UsernameNotFoundException ex) { + String message = this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound", + new Object[] { digestAuth.getUsername() }, "Username {0} not found"); + fail(request, response, new BadCredentialsException(message)); return; } - // If digest is still incorrect, definitely reject authentication attempt if (!serverDigestMd5.equals(digestAuth.getResponse())) { - if (logger.isDebugEnabled()) { - logger.debug("Expected response: '" + serverDigestMd5 + "' but received: '" + digestAuth.getResponse() - + "'; is AuthenticationDao returning clear text passwords?"); - } - - fail(request, response, new BadCredentialsException( - this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse", "Incorrect response"))); + logger.debug(LogMessage.format( + "Expected response: '%s' but received: '%s'; is AuthenticationDao returning clear text passwords?", + serverDigestMd5, digestAuth.getResponse())); + String message = this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse", + "Incorrect response"); + fail(request, response, new BadCredentialsException(message)); return; } - // To get this far, the digest must have been valid // Check the nonce has not expired // We do this last so we can direct the user agent its nonce is stale // but the request was otherwise appearing to be valid if (digestAuth.isNonceExpired()) { - fail(request, response, new NonceExpiredException(this.messages - .getMessage("DigestAuthenticationFilter.nonceExpired", "Nonce has expired/timed out"))); - + String message = this.messages.getMessage("DigestAuthenticationFilter.nonceExpired", + "Nonce has expired/timed out"); + fail(request, response, new NonceExpiredException(message)); return; } - - if (logger.isDebugEnabled()) { - logger.debug("Authentication success for user: '" + digestAuth.getUsername() + "' with response: '" - + digestAuth.getResponse() + "'"); - } - + logger.debug(LogMessage.format("Authentication success for user: '%s' with response: '%s'", + digestAuth.getUsername(), digestAuth.getResponse())); Authentication authentication = createSuccessfulAuthentication(request, user); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); - chain.doFilter(request, response); } private Authentication createSuccessfulAuthentication(HttpServletRequest request, UserDetails user) { - UsernamePasswordAuthenticationToken authRequest; - if (this.createAuthenticatedToken) { - authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities()); - } - else { - authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword()); - } - + UsernamePasswordAuthenticationToken authRequest = getAuthRequest(user); authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - return authRequest; } + private UsernamePasswordAuthenticationToken getAuthRequest(UserDetails user) { + if (this.createAuthenticatedToken) { + return new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities()); + } + return new UsernamePasswordAuthenticationToken(user, user.getPassword()); + } + private void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException { SecurityContextHolder.getContext().setAuthentication(null); - - if (logger.isDebugEnabled()) { - logger.debug(failed); - } - + logger.debug(failed); this.authenticationEntryPoint.commence(request, response, failed); } @@ -326,7 +296,6 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes this.section212response = header.substring(7); String[] headerEntries = DigestAuthUtils.splitIgnoringQuotes(this.section212response, ','); Map headerMap = DigestAuthUtils.splitEachArrayElementAndCreateMap(headerEntries, "=", "\""); - this.username = headerMap.get("username"); this.realm = headerMap.get("realm"); this.nonce = headerMap.get("nonce"); @@ -335,11 +304,9 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes this.qop = headerMap.get("qop"); // RFC 2617 extension this.nc = headerMap.get("nc"); // RFC 2617 extension this.cnonce = headerMap.get("cnonce"); // RFC 2617 extension - - if (logger.isDebugEnabled()) { - logger.debug("Extracted username: '" + this.username + "'; realm: '" + this.realm + "'; nonce: '" - + this.nonce + "'; uri: '" + this.uri + "'; response: '" + this.response + "'"); - } + logger.debug( + LogMessage.format("Extracted username: '%s'; realm: '%s'; nonce: '%s'; uri: '%s'; response: '%s'", + this.username, this.realm, this.nonce, this.uri, this.response)); } void validateAndDecode(String entryPointKey, String expectedRealm) throws BadCredentialsException { @@ -353,23 +320,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes // Check all required parameters for an "auth" qop were supplied (ie RFC 2617) if ("auth".equals(this.qop)) { if ((this.nc == null) || (this.cnonce == null)) { - if (logger.isDebugEnabled()) { - logger.debug("extracted nc: '" + this.nc + "'; cnonce: '" + this.cnonce + "'"); - } - + logger.debug(LogMessage.format("extracted nc: '%s'; cnonce: '%s'", this.nc, this.cnonce)); throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( "DigestAuthenticationFilter.missingAuth", new Object[] { this.section212response }, "Missing mandatory digest value; received header {0}")); } } - // Check realm name equals what we expected if (!expectedRealm.equals(this.realm)) { throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( "DigestAuthenticationFilter.incorrectRealm", new Object[] { this.realm, expectedRealm }, "Response realm name '{0}' does not match system realm name of '{1}'")); } - // Check nonce was Base64 encoded (as sent by DigestAuthenticationEntryPoint) try { Base64.getDecoder().decode(this.nonce.getBytes()); @@ -379,21 +341,16 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes DigestAuthenticationFilter.this.messages.getMessage("DigestAuthenticationFilter.nonceEncoding", new Object[] { this.nonce }, "Nonce is not encoded in Base64; received nonce {0}")); } - - // Decode nonce from Base64 - // format of nonce is: - // base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key)) + // Decode nonce from Base64 format of nonce is: base64(expirationTime + ":" + + // md5Hex(expirationTime + ":" + key)) String nonceAsPlainText = new String(Base64.getDecoder().decode(this.nonce.getBytes())); String[] nonceTokens = StringUtils.delimitedListToStringArray(nonceAsPlainText, ":"); - if (nonceTokens.length != 2) { throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( "DigestAuthenticationFilter.nonceNotTwoTokens", new Object[] { nonceAsPlainText }, "Nonce should have yielded two tokens but was {0}")); } - // Extract expiry time from nonce - try { this.nonceExpiryTime = new Long(nonceTokens[0]); } @@ -402,10 +359,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes "DigestAuthenticationFilter.nonceNotNumeric", new Object[] { nonceAsPlainText }, "Nonce token should have yielded a numeric first token, but was {0}")); } - // Check signature of nonce matches this expiry time String expectedNonceSignature = DigestAuthUtils.md5Hex(this.nonceExpiryTime + ":" + entryPointKey); - if (!expectedNonceSignature.equals(nonceTokens[1])) { throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage( "DigestAuthenticationFilter.nonceCompromised", new Object[] { nonceAsPlainText }, @@ -414,9 +369,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes } String calculateServerDigest(String password, String httpMethod) { - // Compute the expected response-digest (will be in hex form) - - // Don't catch IllegalArgumentException (already checked validity) + // Compute the expected response-digest (will be in hex form). Don't catch + // IllegalArgumentException (already checked validity) return DigestAuthUtils.generateDigest(DigestAuthenticationFilter.this.passwordAlreadyEncoded, this.username, this.realm, password, httpMethod, this.uri, this.qop, this.nonce, this.nc, this.cnonce); } diff --git a/web/src/main/java/org/springframework/security/web/bind/support/AuthenticationPrincipalArgumentResolver.java b/web/src/main/java/org/springframework/security/web/bind/support/AuthenticationPrincipalArgumentResolver.java index fcda61caf1..2ff185989c 100644 --- a/web/src/main/java/org/springframework/security/web/bind/support/AuthenticationPrincipalArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/bind/support/AuthenticationPrincipalArgumentResolver.java @@ -105,9 +105,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet if (authPrincipal.errorOnInvalidType()) { throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } return principal; } diff --git a/web/src/main/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializer.java b/web/src/main/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializer.java index f722374ab1..2df5c7a2eb 100644 --- a/web/src/main/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializer.java +++ b/web/src/main/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializer.java @@ -173,11 +173,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp */ private void registerFilters(ServletContext servletContext, boolean insertBeforeOtherFilters, Filter... filters) { Assert.notEmpty(filters, "filters cannot be null or empty"); - for (Filter filter : filters) { - if (filter == null) { - throw new IllegalArgumentException("filters cannot contain null values. Got " + Arrays.asList(filters)); - } + Assert.notNull(filter, () -> "filters cannot contain null values. Got " + Arrays.asList(filters)); String filterName = Conventions.getVariableName(filter); registerFilter(servletContext, insertBeforeOtherFilters, filterName, filter); } @@ -195,10 +192,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp private void registerFilter(ServletContext servletContext, boolean insertBeforeOtherFilters, String filterName, Filter filter) { Dynamic registration = servletContext.addFilter(filterName, filter); - if (registration == null) { - throw new IllegalStateException("Duplicate Filter registration for '" + filterName - + "'. Check to ensure the Filter is only configured once."); - } + Assert.state(registration != null, () -> "Duplicate Filter registration for '" + filterName + + "'. Check to ensure the Filter is only configured once."); registration.setAsyncSupported(isAsyncSecuritySupported()); EnumSet dispatcherTypes = getSecurityDispatcherTypes(); registration.addMappingForUrlPatterns(dispatcherTypes, !insertBeforeOtherFilters, "/*"); diff --git a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java index c0e9b958a6..e58ad9ef6b 100644 --- a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java @@ -28,6 +28,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; @@ -115,24 +116,18 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo HttpServletRequest request = requestResponseHolder.getRequest(); HttpServletResponse response = requestResponseHolder.getResponse(); HttpSession httpSession = request.getSession(false); - SecurityContext context = readSecurityContextFromSession(httpSession); - if (context == null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("No SecurityContext was available from the HttpSession: " + httpSession + ". " - + "A new one will be created."); - } + this.logger.debug(LogMessage.format( + "No SecurityContext was available from the HttpSession: %s. A new one will be created.", + httpSession)); context = generateNewContext(); } - SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request, httpSession != null, context); requestResponseHolder.setResponse(wrappedResponse); - requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse)); - return context; } @@ -140,13 +135,10 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response, SaveContextOnUpdateOrErrorResponseWrapper.class); - if (responseWrapper == null) { - throw new IllegalStateException("Cannot invoke saveContext on response " + response - + ". You must use the HttpRequestResponseHolder.response after invoking loadContext"); - } - // saveContext() might already be called by the response wrapper - // if something in the chain called sendError() or sendRedirect(). This ensures we - // only call it + Assert.state(responseWrapper != null, () -> "Cannot invoke saveContext on response " + response + + ". You must use the HttpRequestResponseHolder.response after invoking loadContext"); + // saveContext() might already be called by the response wrapper if something in + // the chain called sendError() or sendRedirect(). This ensures we only call it // once per request. if (!responseWrapper.isContextSaved()) { responseWrapper.saveContext(context); @@ -156,11 +148,9 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo @Override public boolean containsContext(HttpServletRequest request) { HttpSession session = request.getSession(false); - if (session == null) { return false; } - return session.getAttribute(this.springSecurityContextKey) != null; } @@ -168,47 +158,30 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo * @param httpSession the session obtained from the request. */ private SecurityContext readSecurityContextFromSession(HttpSession httpSession) { - final boolean debug = this.logger.isDebugEnabled(); - if (httpSession == null) { - if (debug) { - this.logger.debug("No HttpSession currently exists"); - } - + this.logger.debug("No HttpSession currently exists"); return null; } - // Session exists, so try to obtain a context from it. - Object contextFromSession = httpSession.getAttribute(this.springSecurityContextKey); - if (contextFromSession == null) { - if (debug) { - this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT"); - } - + this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT"); return null; } // We now have the security context object from the session. if (!(contextFromSession instanceof SecurityContext)) { - if (this.logger.isWarnEnabled()) { - this.logger.warn(this.springSecurityContextKey + " did not contain a SecurityContext but contained: '" - + contextFromSession + "'; are you improperly modifying the HttpSession directly " - + "(you should always use SecurityContextHolder) or using the HttpSession attribute " - + "reserved for this class?"); - } - + this.logger.warn(LogMessage.format( + "%s did not contain a SecurityContext but contained: '%s'; are you improperly " + + "modifying the HttpSession directly (you should always use SecurityContextHolder) " + + "or using the HttpSession attribute reserved for this class?", + this.springSecurityContextKey, contextFromSession)); return null; } - if (debug) { - this.logger.debug("Obtained a valid SecurityContext from " + this.springSecurityContextKey + ": '" - + contextFromSession + "'"); - } - + this.logger.debug(LogMessage.format("Obtained a valid SecurityContext from %s: '%s'", + this.springSecurityContextKey, contextFromSession)); // Everything OK. The only non-null return from this method. - return (SecurityContext) contextFromSession; } @@ -306,6 +279,8 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo */ final class SaveToSessionResponseWrapper extends SaveContextOnUpdateOrErrorResponseWrapper { + private final Log logger = HttpSessionSecurityContextRepository.this.logger; + private final HttpServletRequest request; private final boolean httpSessionExistedAtStartOfRequest; @@ -349,41 +324,29 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo protected void saveContext(SecurityContext context) { final Authentication authentication = context.getAuthentication(); HttpSession httpSession = this.request.getSession(false); - + String springSecurityContextKey = HttpSessionSecurityContextRepository.this.springSecurityContextKey; // See SEC-776 if (authentication == null || HttpSessionSecurityContextRepository.this.trustResolver.isAnonymous(authentication)) { - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger.debug( - "SecurityContext is empty or contents are anonymous - context will not be stored in HttpSession."); - } - + this.logger.debug("SecurityContext is empty or contents are anonymous - " + + "context will not be stored in HttpSession."); if (httpSession != null && this.authBeforeExecution != null) { // SEC-1587 A non-anonymous context may still be in the session // SEC-1735 remove if the contextBeforeExecution was not anonymous - httpSession.removeAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey); + httpSession.removeAttribute(springSecurityContextKey); } return; } - - if (httpSession == null) { - httpSession = createNewSessionIfAllowed(context); - } - + httpSession = (httpSession != null) ? httpSession : createNewSessionIfAllowed(context); // If HttpSession exists, store current SecurityContext but only if it has // actually changed in this thread (see SEC-37, SEC-1307, SEC-1528) if (httpSession != null) { // We may have a new session, so check also whether the context attribute // is set SEC-1561 - if (contextChanged(context) || httpSession - .getAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey) == null) { - httpSession.setAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey, - context); - - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger - .debug("SecurityContext '" + context + "' stored to HttpSession: '" + httpSession); - } + if (contextChanged(context) || httpSession.getAttribute(springSecurityContextKey) == null) { + httpSession.setAttribute(springSecurityContextKey, context); + this.logger.debug(LogMessage.format("SecurityContext '%s' stored to HttpSession: '%s'", context, + httpSession)); } } } @@ -396,56 +359,37 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo if (isTransientAuthentication(context.getAuthentication())) { return null; } - if (this.httpSessionExistedAtStartOfRequest) { - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger - .debug("HttpSession is now null, but was not null at start of request; " - + "session was invalidated, so do not create a new session"); - } - + this.logger.debug("HttpSession is now null, but was not null at start of request; " + + "session was invalidated, so do not create a new session"); return null; } - if (!HttpSessionSecurityContextRepository.this.allowSessionCreation) { - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger.debug("The HttpSession is currently null, and the " - + HttpSessionSecurityContextRepository.class.getSimpleName() - + " is prohibited from creating an HttpSession " - + "(because the allowSessionCreation property is false) - SecurityContext thus not " - + "stored for next request"); - } - + this.logger.debug("The HttpSession is currently null, and the " + + HttpSessionSecurityContextRepository.class.getSimpleName() + + " is prohibited from creating an HttpSession " + + "(because the allowSessionCreation property is false) - SecurityContext thus not " + + "stored for next request"); return null; } // Generate a HttpSession only if we need to - if (HttpSessionSecurityContextRepository.this.contextObject.equals(context)) { - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger.debug( - "HttpSession is null, but SecurityContext has not changed from default empty context: ' " - + context + "'; not creating HttpSession or storing SecurityContext"); - } - + this.logger.debug(LogMessage.format( + "HttpSession is null, but SecurityContext has not changed from " + + "default empty context: '%s'; not creating HttpSession or storing SecurityContext", + context)); return null; } - - if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) { - HttpSessionSecurityContextRepository.this.logger - .debug("HttpSession being created as SecurityContext is non-default"); - } - + this.logger.debug("HttpSession being created as SecurityContext is non-default"); try { return this.request.getSession(true); } catch (IllegalStateException ex) { // Response must already be committed, therefore can't create a new // session - HttpSessionSecurityContextRepository.this.logger - .warn("Failed to create a session, as response has been committed. Unable to store" - + " SecurityContext."); + this.logger.warn("Failed to create a session, as response has been committed. " + + "Unable to store SecurityContext."); } - return null; } diff --git a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java index ab021a8525..4fb534a565 100644 --- a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java @@ -44,7 +44,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommit private boolean contextSaved = false; - /* See SEC-1052 */ + // See SEC-1052 private final boolean disableUrlRewriting; /** diff --git a/web/src/main/java/org/springframework/security/web/context/SecurityContextPersistenceFilter.java b/web/src/main/java/org/springframework/security/web/context/SecurityContextPersistenceFilter.java index 67a623cf3d..0077c71d46 100644 --- a/web/src/main/java/org/springframework/security/web/context/SecurityContextPersistenceFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/SecurityContextPersistenceFilter.java @@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.filter.GenericFilterBean; @@ -74,49 +75,36 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { + // ensure that filter is only applied once per request if (request.getAttribute(FILTER_APPLIED) != null) { - // ensure that filter is only applied once per request chain.doFilter(request, response); return; } - - final boolean debug = this.logger.isDebugEnabled(); - request.setAttribute(FILTER_APPLIED, Boolean.TRUE); - if (this.forceEagerSessionCreation) { HttpSession session = request.getSession(); - - if (debug && session.isNew()) { - this.logger.debug("Eagerly created session: " + session.getId()); - } + this.logger.debug(LogMessage.format("Eagerly created session: %s", session.getId())); } - HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder); - try { SecurityContextHolder.setContext(contextBeforeChainExecution); - chain.doFilter(holder.getRequest(), holder.getResponse()); - } finally { SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext(); - // Crucial removal of SecurityContextHolder contents - do this before anything - // else. + // Crucial removal of SecurityContextHolder contents before anything else. SecurityContextHolder.clearContext(); this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse()); request.removeAttribute(FILTER_APPLIED); - - if (debug) { - this.logger.debug("SecurityContextHolder now cleared, as request processing completed"); - } + this.logger.debug("SecurityContextHolder now cleared, as request processing completed"); } } diff --git a/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java b/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java index 8f2f091bfc..dbe0c65f4c 100644 --- a/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java @@ -46,14 +46,12 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); - SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager .getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY); if (securityProcessingInterceptor == null) { asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, new SecurityContextCallableProcessingInterceptor()); } - filterChain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java b/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java index ba4721eee8..98e267e04e 100644 --- a/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java +++ b/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java @@ -20,6 +20,7 @@ import java.util.Enumeration; import javax.servlet.ServletContext; +import org.springframework.util.Assert; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.WebApplicationContextUtils; @@ -47,11 +48,10 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC * @see ServletContext#getAttributeNames() */ public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) { - WebApplicationContext wac = _findWebApplicationContext(servletContext); - if (wac == null) { - throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?"); - } - return wac; + WebApplicationContext webApplicationContext = compatiblyFindWebApplicationContext(servletContext); + Assert.state(webApplicationContext != null, + "No WebApplicationContext found: no ContextLoaderListener registered?"); + return webApplicationContext; } /** @@ -59,23 +59,21 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC * spring framework 4.1.x. * @see #findWebApplicationContext(ServletContext) */ - private static WebApplicationContext _findWebApplicationContext(ServletContext sc) { - WebApplicationContext wac = getWebApplicationContext(sc); - if (wac == null) { + private static WebApplicationContext compatiblyFindWebApplicationContext(ServletContext sc) { + WebApplicationContext webApplicationContext = getWebApplicationContext(sc); + if (webApplicationContext == null) { Enumeration attrNames = sc.getAttributeNames(); while (attrNames.hasMoreElements()) { String attrName = attrNames.nextElement(); Object attrValue = sc.getAttribute(attrName); if (attrValue instanceof WebApplicationContext) { - if (wac != null) { - throw new IllegalStateException("No unique WebApplicationContext found: more than one " - + "DispatcherServlet registered with publishContext=true?"); - } - wac = (WebApplicationContext) attrValue; + Assert.state(webApplicationContext == null, "No unique WebApplicationContext found: more than one " + + "DispatcherServlet registered with publishContext=true?"); + webApplicationContext = (WebApplicationContext) attrValue; } } } - return wac; + return webApplicationContext; } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java index 761785fb8e..473dbaf56c 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java @@ -69,30 +69,13 @@ public final class CookieCsrfTokenRepository implements CsrfTokenRepository { public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { String tokenValue = (token != null) ? token.getToken() : ""; Cookie cookie = new Cookie(this.cookieName, tokenValue); - if (this.secure == null) { - cookie.setSecure(request.isSecure()); - } - else { - cookie.setSecure(this.secure); - } - - if (this.cookiePath != null && !this.cookiePath.isEmpty()) { - cookie.setPath(this.cookiePath); - } - else { - cookie.setPath(this.getRequestContext(request)); - } - if (token == null) { - cookie.setMaxAge(0); - } - else { - cookie.setMaxAge(-1); - } + cookie.setSecure((this.secure != null) ? this.secure : request.isSecure()); + cookie.setPath(StringUtils.hasLength(this.cookiePath) ? this.cookiePath : this.getRequestContext(request)); + cookie.setMaxAge((token != null) ? -1 : 0); cookie.setHttpOnly(this.cookieHttpOnly); - if (this.cookieDomain != null && !this.cookieDomain.isEmpty()) { + if (StringUtils.hasLength(this.cookieDomain)) { cookie.setDomain(this.cookieDomain); } - response.addCookie(cookie); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index 60af71fbea..54de46b586 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -51,10 +51,8 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; if (containsToken) { this.csrfTokenRepository.saveToken(null, request, response); - CsrfToken newToken = this.csrfTokenRepository.generateToken(request); this.csrfTokenRepository.saveToken(newToken, request, response); - request.setAttribute(CsrfToken.class.getName(), newToken); request.setAttribute(newToken.getParameterName(), newToken); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index e2b34d8f46..8f91907895 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -29,6 +29,8 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; +import org.springframework.security.access.AccessDeniedException; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.util.UrlUtils; @@ -97,39 +99,30 @@ public final class CsrfFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { request.setAttribute(HttpServletResponse.class.getName(), response); - CsrfToken csrfToken = this.tokenRepository.loadToken(request); - final boolean missingToken = csrfToken == null; + boolean missingToken = (csrfToken == null); if (missingToken) { csrfToken = this.tokenRepository.generateToken(request); this.tokenRepository.saveToken(csrfToken, request, response); } request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken); - if (!this.requireCsrfProtectionMatcher.matches(request)) { filterChain.doFilter(request, response); return; } - String actualToken = request.getHeader(csrfToken.getHeaderName()); if (actualToken == null) { actualToken = request.getParameter(csrfToken.getParameterName()); } if (!csrfToken.getToken().equals(actualToken)) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)); - } - if (missingToken) { - this.accessDeniedHandler.handle(request, response, new MissingCsrfTokenException(actualToken)); - } - else { - this.accessDeniedHandler.handle(request, response, - new InvalidCsrfTokenException(csrfToken, actualToken)); - } + this.logger.debug( + LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request))); + AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken) + : new MissingCsrfTokenException(actualToken); + this.accessDeniedHandler.handle(request, response, exception); return; } - filterChain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java index 9a2cb42a75..bc59a2e496 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java @@ -24,7 +24,6 @@ import java.io.Serializable; * @author Rob Winch * @since 3.2 * @see DefaultCsrfToken - * */ public interface CsrfToken extends Serializable { diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index fcb044a327..b0f4263f2a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -87,11 +87,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { private HttpServletResponse getResponse(HttpServletRequest request) { HttpServletResponse response = (HttpServletResponse) request.getAttribute(HTTP_RESPONSE_ATTR); - if (response == null) { - throw new IllegalArgumentException( - "The HttpServletRequest attribute must contain an HttpServletResponse for the attribute " - + HTTP_RESPONSE_ATTR); - } + Assert.notNull(response, () -> "The HttpServletRequest attribute must contain an HttpServletResponse " + + "for the attribute " + HTTP_RESPONSE_ATTR); return response; } @@ -166,7 +163,6 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { if (this.tokenRepository == null) { return; } - synchronized (this) { if (this.tokenRepository != null) { this.tokenRepository.saveToken(this.delegate, this.request, this.response); diff --git a/web/src/main/java/org/springframework/security/web/debug/DebugFilter.java b/web/src/main/java/org/springframework/security/web/debug/DebugFilter.java index c1fa2cb84a..a25b7927ac 100644 --- a/web/src/main/java/org/springframework/security/web/debug/DebugFilter.java +++ b/web/src/main/java/org/springframework/security/web/debug/DebugFilter.java @@ -50,35 +50,35 @@ public final class DebugFilter implements Filter { static final String ALREADY_FILTERED_ATTR_NAME = DebugFilter.class.getName().concat(".FILTERED"); - private final FilterChainProxy fcp; + private final FilterChainProxy filterChainProxy; private final Logger logger = new Logger(); - public DebugFilter(FilterChainProxy fcp) { - this.fcp = fcp; + public DebugFilter(FilterChainProxy filterChainProxy) { + this.filterChainProxy = filterChainProxy; } @Override - public void doFilter(ServletRequest srvltRequest, ServletResponse srvltResponse, FilterChain filterChain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws ServletException, IOException { - - if (!(srvltRequest instanceof HttpServletRequest) || !(srvltResponse instanceof HttpServletResponse)) { + if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) { throw new ServletException("DebugFilter just supports HTTP requests"); } - HttpServletRequest request = (HttpServletRequest) srvltRequest; - HttpServletResponse response = (HttpServletResponse) srvltResponse; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, filterChain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws IOException, ServletException { List filters = getFilters(request); this.logger.info("Request received for " + request.getMethod() + " '" + UrlUtils.buildRequestUrl(request) + "':\n\n" + request + "\n\n" + "servletPath:" + request.getServletPath() + "\n" + "pathInfo:" + request.getPathInfo() + "\n" + "headers: \n" + formatHeaders(request) + "\n\n" + formatFilters(filters)); - if (request.getAttribute(ALREADY_FILTERED_ATTR_NAME) == null) { invokeWithWrappedRequest(request, response, filterChain); } else { - this.fcp.doFilter(request, response, filterChain); + this.filterChainProxy.doFilter(request, response, filterChain); } } @@ -87,7 +87,7 @@ public final class DebugFilter implements Filter { request.setAttribute(ALREADY_FILTERED_ATTR_NAME, Boolean.TRUE); request = new DebugRequestWrapper(request); try { - this.fcp.doFilter(request, response, filterChain); + this.filterChainProxy.doFilter(request, response, filterChain); } finally { request.removeAttribute(ALREADY_FILTERED_ATTR_NAME); @@ -134,7 +134,7 @@ public final class DebugFilter implements Filter { } private List getFilters(HttpServletRequest request) { - for (SecurityFilterChain chain : this.fcp.getFilterChains()) { + for (SecurityFilterChain chain : this.filterChainProxy.getFilterChains()) { if (chain.matches(request)) { return chain.getFilters(); } @@ -163,11 +163,9 @@ public final class DebugFilter implements Filter { public HttpSession getSession() { boolean sessionExists = super.getSession(false) != null; HttpSession session = super.getSession(); - if (!sessionExists) { DebugRequestWrapper.logger.info("New HTTP session created: " + session.getId(), true); } - return session; } diff --git a/web/src/main/java/org/springframework/security/web/firewall/DefaultHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/DefaultHttpFirewall.java index ed6b51922a..aec01583a5 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/DefaultHttpFirewall.java +++ b/web/src/main/java/org/springframework/security/web/firewall/DefaultHttpFirewall.java @@ -50,19 +50,17 @@ public class DefaultHttpFirewall implements HttpFirewall { @Override public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException { - FirewalledRequest fwr = new RequestWrapper(request); - - if (!isNormalized(fwr.getServletPath()) || !isNormalized(fwr.getPathInfo())) { - throw new RequestRejectedException("Un-normalized paths are not supported: " + fwr.getServletPath() - + ((fwr.getPathInfo() != null) ? fwr.getPathInfo() : "")); + FirewalledRequest firewalledRequest = new RequestWrapper(request); + if (!isNormalized(firewalledRequest.getServletPath()) || !isNormalized(firewalledRequest.getPathInfo())) { + throw new RequestRejectedException( + "Un-normalized paths are not supported: " + firewalledRequest.getServletPath() + + ((firewalledRequest.getPathInfo() != null) ? firewalledRequest.getPathInfo() : "")); } - - String requestURI = fwr.getRequestURI(); + String requestURI = firewalledRequest.getRequestURI(); if (containsInvalidUrlEncodedSlash(requestURI)) { throw new RequestRejectedException("The requestURI cannot contain encoded slash. Got " + requestURI); } - - return fwr; + return firewalledRequest; } @Override @@ -89,11 +87,9 @@ public class DefaultHttpFirewall implements HttpFirewall { if (this.allowUrlEncodedSlash || uri == null) { return false; } - if (uri.contains("%2f") || uri.contains("%2F")) { return true; } - return false; } @@ -107,22 +103,18 @@ public class DefaultHttpFirewall implements HttpFirewall { if (path == null) { return true; } - - for (int j = path.length(); j > 0;) { - int i = path.lastIndexOf('/', j - 1); - int gap = j - i; - - if (gap == 2 && path.charAt(i + 1) == '.') { + for (int i = path.length(); i > 0;) { + int slashIndex = path.lastIndexOf('/', i - 1); + int gap = i - slashIndex; + if (gap == 2 && path.charAt(slashIndex + 1) == '.') { // ".", "/./" or "/." return false; } - else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') { + if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') { return false; } - - j = i; + i = slashIndex; } - return true; } diff --git a/web/src/main/java/org/springframework/security/web/firewall/FirewalledResponse.java b/web/src/main/java/org/springframework/security/web/firewall/FirewalledResponse.java index 83db1f9932..0f0bca9b1c 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/FirewalledResponse.java +++ b/web/src/main/java/org/springframework/security/web/firewall/FirewalledResponse.java @@ -22,6 +22,8 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; +import org.springframework.util.Assert; + /** * @author Luke Taylor * @author Eddú Meléndez @@ -71,9 +73,7 @@ class FirewalledResponse extends HttpServletResponseWrapper { } void validateCrlf(String name, String value) { - if (hasCrlf(name) || hasCrlf(value)) { - throw new IllegalArgumentException("Invalid characters (CR/LF) in header " + name); - } + Assert.isTrue(!hasCrlf(name) && !hasCrlf(value), () -> "Invalid characters (CR/LF) in header " + name); } private boolean hasCrlf(String value) { diff --git a/web/src/main/java/org/springframework/security/web/firewall/HttpStatusRequestRejectedHandler.java b/web/src/main/java/org/springframework/security/web/firewall/HttpStatusRequestRejectedHandler.java index 0946038da9..36a9df93e2 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/HttpStatusRequestRejectedHandler.java +++ b/web/src/main/java/org/springframework/security/web/firewall/HttpStatusRequestRejectedHandler.java @@ -24,6 +24,8 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; + /** * A simple implementation of {@link RequestRejectedHandler} that sends an error with * configurable status code. @@ -55,10 +57,8 @@ public class HttpStatusRequestRejectedHandler implements RequestRejectedHandler @Override public void handle(HttpServletRequest request, HttpServletResponse response, RequestRejectedException requestRejectedException) throws IOException { - if (logger.isDebugEnabled()) { - logger.debug("Rejecting request due to: " + requestRejectedException.getMessage(), - requestRejectedException); - } + logger.debug(LogMessage.format("Rejecting request due to: %s", requestRejectedException.getMessage()), + requestRejectedException); response.sendError(this.httpError); } diff --git a/web/src/main/java/org/springframework/security/web/firewall/RequestWrapper.java b/web/src/main/java/org/springframework/security/web/firewall/RequestWrapper.java index d0c559af22..80a41a0a82 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/RequestWrapper.java +++ b/web/src/main/java/org/springframework/security/web/firewall/RequestWrapper.java @@ -74,10 +74,8 @@ final class RequestWrapper extends FirewalledRequest { if (path == null) { return null; } - - int scIndex = path.indexOf(';'); - - if (scIndex < 0) { + int semicolonIndex = path.indexOf(';'); + if (semicolonIndex < 0) { int doubleSlashIndex = path.indexOf("//"); if (doubleSlashIndex < 0) { // Most likely case, no parameters in any segment and no '//', so no @@ -85,29 +83,23 @@ final class RequestWrapper extends FirewalledRequest { return path; } } - - StringTokenizer st = new StringTokenizer(path, "/"); + StringTokenizer tokenizer = new StringTokenizer(path, "/"); StringBuilder stripped = new StringBuilder(path.length()); - if (path.charAt(0) == '/') { stripped.append('/'); } - - while (st.hasMoreTokens()) { - String segment = st.nextToken(); - scIndex = segment.indexOf(';'); - - if (scIndex >= 0) { - segment = segment.substring(0, scIndex); + while (tokenizer.hasMoreTokens()) { + String segment = tokenizer.nextToken(); + semicolonIndex = segment.indexOf(';'); + if (semicolonIndex >= 0) { + segment = segment.substring(0, semicolonIndex); } stripped.append(segment).append('/'); } - // Remove the trailing slash if the original path didn't have one if (path.charAt(path.length() - 1) != '/') { stripped.deleteCharAt(stripped.length() - 1); } - return stripped.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java index f8893c0c55..9aa6868aca 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java +++ b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java @@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; /** *

@@ -83,7 +84,7 @@ public class StrictHttpFirewall implements HttpFirewall { * Used to specify to {@link #setAllowedHttpMethods(Collection)} that any HTTP method * should be allowed. */ - private static final Set ALLOW_ANY_HTTP_METHOD = Collections.unmodifiableSet(Collections.emptySet()); + private static final Set ALLOW_ANY_HTTP_METHOD = Collections.emptySet(); private static final String ENCODED_PERCENT = "%25"; @@ -165,15 +166,9 @@ public class StrictHttpFirewall implements HttpFirewall { * @see #setUnsafeAllowAnyHttpMethod(boolean) */ public void setAllowedHttpMethods(Collection allowedHttpMethods) { - if (allowedHttpMethods == null) { - throw new IllegalArgumentException("allowedHttpMethods cannot be null"); - } - if (allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) { - this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD; - } - else { - this.allowedHttpMethods = new HashSet<>(allowedHttpMethods); - } + Assert.notNull(allowedHttpMethods, "allowedHttpMethods cannot be null"); + this.allowedHttpMethods = (allowedHttpMethods != ALLOW_ANY_HTTP_METHOD) ? new HashSet<>(allowedHttpMethods) + : ALLOW_ANY_HTTP_METHOD; } /** @@ -361,9 +356,7 @@ public class StrictHttpFirewall implements HttpFirewall { * @see Character#isDefined(int) */ public void setAllowedHeaderNames(Predicate allowedHeaderNames) { - if (allowedHeaderNames == null) { - throw new IllegalArgumentException("allowedHeaderNames cannot be null"); - } + Assert.notNull(allowedHeaderNames, "allowedHeaderNames cannot be null"); this.allowedHeaderNames = allowedHeaderNames; } @@ -378,28 +371,20 @@ public class StrictHttpFirewall implements HttpFirewall { * @see Character#isDefined(int) */ public void setAllowedHeaderValues(Predicate allowedHeaderValues) { - if (allowedHeaderValues == null) { - throw new IllegalArgumentException("allowedHeaderValues cannot be null"); - } + Assert.notNull(allowedHeaderValues, "allowedHeaderValues cannot be null"); this.allowedHeaderValues = allowedHeaderValues; } - /* + /** * Determines which parameter names should be allowed. The default is to reject header - * names that contain ISO control characters and characters that are not defined.

- * + * names that contain ISO control characters and characters that are not defined. * @param allowedParameterNames the predicate for testing parameter names - * - * @see Character#isISOControl(int) - * - * @see Character#isDefined(int) - * * @since 5.4 + * @see Character#isISOControl(int) + * @see Character#isDefined(int) */ public void setAllowedParameterNames(Predicate allowedParameterNames) { - if (allowedParameterNames == null) { - throw new IllegalArgumentException("allowedParameterNames cannot be null"); - } + Assert.notNull(allowedParameterNames, "allowedParameterNames cannot be null"); this.allowedParameterNames = allowedParameterNames; } @@ -412,9 +397,7 @@ public class StrictHttpFirewall implements HttpFirewall { * @since 5.4 */ public void setAllowedParameterValues(Predicate allowedParameterValues) { - if (allowedParameterValues == null) { - throw new IllegalArgumentException("allowedParameterValues cannot be null"); - } + Assert.notNull(allowedParameterValues, "allowedParameterValues cannot be null"); this.allowedParameterValues = allowedParameterValues; } @@ -426,9 +409,7 @@ public class StrictHttpFirewall implements HttpFirewall { * @since 5.2 */ public void setAllowedHostnames(Predicate allowedHostnames) { - if (allowedHostnames == null) { - throw new IllegalArgumentException("allowedHostnames cannot be null"); - } + Assert.notNull(allowedHostnames, "allowedHostnames cannot be null"); this.allowedHostnames = allowedHostnames; } @@ -447,173 +428,15 @@ public class StrictHttpFirewall implements HttpFirewall { rejectForbiddenHttpMethod(request); rejectedBlocklistedUrls(request); rejectedUntrustedHosts(request); - if (!isNormalized(request)) { throw new RequestRejectedException("The request was rejected because the URL was not normalized."); } - String requestUri = request.getRequestURI(); if (!containsOnlyPrintableAsciiCharacters(requestUri)) { throw new RequestRejectedException( "The requestURI was rejected because it can only contain printable ASCII characters."); } - return new FirewalledRequest(request) { - @Override - public long getDateHeader(String name) { - if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the header name \"" + name + "\" is not allowed."); - } - return super.getDateHeader(name); - } - - @Override - public int getIntHeader(String name) { - if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the header name \"" + name + "\" is not allowed."); - } - return super.getIntHeader(name); - } - - @Override - public String getHeader(String name) { - if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the header name \"" + name + "\" is not allowed."); - } - String value = super.getHeader(name); - if (value != null && !StrictHttpFirewall.this.allowedHeaderValues.test(value)) { - throw new RequestRejectedException( - "The request was rejected because the header value \"" + value + "\" is not allowed."); - } - return value; - } - - @Override - public Enumeration getHeaders(String name) { - if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the header name \"" + name + "\" is not allowed."); - } - - Enumeration valuesEnumeration = super.getHeaders(name); - return new Enumeration() { - @Override - public boolean hasMoreElements() { - return valuesEnumeration.hasMoreElements(); - } - - @Override - public String nextElement() { - String value = valuesEnumeration.nextElement(); - if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) { - throw new RequestRejectedException("The request was rejected because the header value \"" - + value + "\" is not allowed."); - } - return value; - } - }; - } - - @Override - public Enumeration getHeaderNames() { - Enumeration namesEnumeration = super.getHeaderNames(); - return new Enumeration() { - @Override - public boolean hasMoreElements() { - return namesEnumeration.hasMoreElements(); - } - - @Override - public String nextElement() { - String name = namesEnumeration.nextElement(); - if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) { - throw new RequestRejectedException("The request was rejected because the header name \"" - + name + "\" is not allowed."); - } - return name; - } - }; - } - - @Override - public String getParameter(String name) { - if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the parameter name \"" + name + "\" is not allowed."); - } - String value = super.getParameter(name); - if (value != null && !StrictHttpFirewall.this.allowedParameterValues.test(value)) { - throw new RequestRejectedException( - "The request was rejected because the parameter value \"" + value + "\" is not allowed."); - } - return value; - } - - @Override - public Map getParameterMap() { - Map parameterMap = super.getParameterMap(); - for (Map.Entry entry : parameterMap.entrySet()) { - String name = entry.getKey(); - String[] values = entry.getValue(); - if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the parameter name \"" + name + "\" is not allowed."); - } - for (String value : values) { - if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) { - throw new RequestRejectedException("The request was rejected because the parameter value \"" - + value + "\" is not allowed."); - } - } - } - return parameterMap; - } - - @Override - public Enumeration getParameterNames() { - Enumeration namesEnumeration = super.getParameterNames(); - return new Enumeration() { - @Override - public boolean hasMoreElements() { - return namesEnumeration.hasMoreElements(); - } - - @Override - public String nextElement() { - String name = namesEnumeration.nextElement(); - if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) { - throw new RequestRejectedException("The request was rejected because the parameter name \"" - + name + "\" is not allowed."); - } - return name; - } - }; - } - - @Override - public String[] getParameterValues(String name) { - if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) { - throw new RequestRejectedException( - "The request was rejected because the parameter name \"" + name + "\" is not allowed."); - } - String[] values = super.getParameterValues(name); - if (values != null) { - for (String value : values) { - if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) { - throw new RequestRejectedException("The request was rejected because the parameter value \"" - + value + "\" is not allowed."); - } - } - } - return values; - } - - @Override - public void reset() { - } - }; + return new StrictFirewalledRequest(request); } private void rejectForbiddenHttpMethod(HttpServletRequest request) { @@ -705,12 +528,11 @@ public class StrictHttpFirewall implements HttpFirewall { private static boolean containsOnlyPrintableAsciiCharacters(String uri) { int length = uri.length(); for (int i = 0; i < length; i++) { - char c = uri.charAt(i); - if (c < '\u0020' || c > '\u007e') { + char ch = uri.charAt(i); + if (ch < '\u0020' || ch > '\u007e') { return false; } } - return true; } @@ -728,22 +550,17 @@ public class StrictHttpFirewall implements HttpFirewall { if (path == null) { return true; } - - for (int j = path.length(); j > 0;) { - int i = path.lastIndexOf('/', j - 1); - int gap = j - i; - - if (gap == 2 && path.charAt(i + 1) == '.') { - // ".", "/./" or "/." + for (int i = path.length(); i > 0;) { + int slashIndex = path.lastIndexOf('/', i - 1); + int gap = i - slashIndex; + if (gap == 2 && path.charAt(slashIndex + 1) == '.') { + return false; // ".", "/./" or "/." + } + if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') { return false; } - else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') { - return false; - } - - j = i; + i = slashIndex; } - return true; } @@ -782,4 +599,166 @@ public class StrictHttpFirewall implements HttpFirewall { return getDecodedUrlBlocklist(); } + /** + * Strict {@link FirewalledRequest}. + */ + private class StrictFirewalledRequest extends FirewalledRequest { + + StrictFirewalledRequest(HttpServletRequest request) { + super(request); + } + + @Override + public long getDateHeader(String name) { + validateAllowedHeaderName(name); + return super.getDateHeader(name); + } + + @Override + public int getIntHeader(String name) { + validateAllowedHeaderName(name); + return super.getIntHeader(name); + } + + @Override + public String getHeader(String name) { + validateAllowedHeaderName(name); + String value = super.getHeader(name); + if (value != null) { + validateAllowedHeaderValue(value); + } + return value; + } + + @Override + public Enumeration getHeaders(String name) { + validateAllowedHeaderName(name); + Enumeration headers = super.getHeaders(name); + return new Enumeration() { + + @Override + public boolean hasMoreElements() { + return headers.hasMoreElements(); + } + + @Override + public String nextElement() { + String value = headers.nextElement(); + validateAllowedHeaderValue(value); + return value; + } + + }; + } + + @Override + public Enumeration getHeaderNames() { + Enumeration names = super.getHeaderNames(); + return new Enumeration() { + + @Override + public boolean hasMoreElements() { + return names.hasMoreElements(); + } + + @Override + public String nextElement() { + String headerNames = names.nextElement(); + validateAllowedHeaderName(headerNames); + return headerNames; + } + + }; + } + + @Override + public String getParameter(String name) { + validateAllowedParameterName(name); + String value = super.getParameter(name); + if (value != null) { + validateAllowedParameterValue(value); + } + return value; + } + + @Override + public Map getParameterMap() { + Map parameterMap = super.getParameterMap(); + for (Map.Entry entry : parameterMap.entrySet()) { + String name = entry.getKey(); + String[] values = entry.getValue(); + validateAllowedParameterName(name); + for (String value : values) { + validateAllowedParameterValue(value); + } + } + return parameterMap; + } + + @Override + public Enumeration getParameterNames() { + Enumeration paramaterNames = super.getParameterNames(); + return new Enumeration() { + + @Override + public boolean hasMoreElements() { + return paramaterNames.hasMoreElements(); + } + + @Override + public String nextElement() { + String name = paramaterNames.nextElement(); + validateAllowedParameterName(name); + return name; + } + + }; + } + + @Override + public String[] getParameterValues(String name) { + validateAllowedParameterName(name); + String[] values = super.getParameterValues(name); + if (values != null) { + for (String value : values) { + validateAllowedParameterValue(value); + } + } + return values; + } + + private void validateAllowedHeaderName(String headerNames) { + if (!StrictHttpFirewall.this.allowedHeaderNames.test(headerNames)) { + throw new RequestRejectedException( + "The request was rejected because the header name \"" + headerNames + "\" is not allowed."); + } + } + + private void validateAllowedHeaderValue(String value) { + if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) { + throw new RequestRejectedException( + "The request was rejected because the header value \"" + value + "\" is not allowed."); + } + } + + private void validateAllowedParameterName(String name) { + if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) { + throw new RequestRejectedException( + "The request was rejected because the parameter name \"" + name + "\" is not allowed."); + } + } + + private void validateAllowedParameterValue(String value) { + if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) { + throw new RequestRejectedException( + "The request was rejected because the parameter value \"" + value + "\" is not allowed."); + } + } + + @Override + public void reset() { + } + + }; + } diff --git a/web/src/main/java/org/springframework/security/web/header/Header.java b/web/src/main/java/org/springframework/security/web/header/Header.java index b1c6aa1576..2356914712 100644 --- a/web/src/main/java/org/springframework/security/web/header/Header.java +++ b/web/src/main/java/org/springframework/security/web/header/Header.java @@ -62,20 +62,18 @@ public final class Header { } @Override - public boolean equals(Object o) { - if (this == o) { + public boolean equals(Object obj) { + if (this == obj) { return true; } - if (o == null || getClass() != o.getClass()) { + if (obj == null || getClass() != obj.getClass()) { return false; } - - Header header = (Header) o; - - if (!this.headerName.equals(header.headerName)) { + Header other = (Header) obj; + if (!this.headerName.equals(other.headerName)) { return false; } - return this.headerValues.equals(header.headerValues); + return this.headerValues.equals(other.headerValues); } @Override diff --git a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java index f59983f47c..38529ecc86 100644 --- a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java +++ b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java @@ -68,7 +68,6 @@ public class HeaderWriterFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (this.shouldWriteHeadersEagerly) { doHeadersBefore(request, response, filterChain); } diff --git a/web/src/main/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriter.java index 3ba25039a8..d4df85fbee 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriter.java @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -76,10 +77,9 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter { response.setHeader(CLEAR_SITE_DATA_HEADER, this.headerValue); } } - else if (this.logger.isDebugEnabled()) { - this.logger.debug("Not injecting Clear-Site-Data header since it did not match the " + "requestMatcher " - + this.requestMatcher); - } + this.logger.debug( + LogMessage.format("Not injecting Clear-Site-Data header since it did not match the requestMatcher %s", + this.requestMatcher)); } private String transformToHeaderValue(Directive... directives) { @@ -97,14 +97,19 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter { } /** - *

- * Represents the directive values expected by the {@link ClearSiteDataHeaderWriter} - *

- * . + * Represents the directive values expected by the {@link ClearSiteDataHeaderWriter}. */ public enum Directive { - CACHE("cache"), COOKIES("cookies"), STORAGE("storage"), EXECUTION_CONTEXTS("executionContexts"), ALL("*"); + CACHE("cache"), + + COOKIES("cookies"), + + STORAGE("storage"), + + EXECUTION_CONTEXTS("executionContexts"), + + ALL("*"); private final String headerValue; diff --git a/web/src/main/java/org/springframework/security/web/header/writers/ContentSecurityPolicyHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/ContentSecurityPolicyHeaderWriter.java index a782869052..5de10d04b8 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/ContentSecurityPolicyHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/ContentSecurityPolicyHeaderWriter.java @@ -117,7 +117,7 @@ public final class ContentSecurityPolicyHeaderWriter implements HeaderWriter { */ @Override public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { - String headerName = !this.reportOnly ? CONTENT_SECURITY_POLICY_HEADER + String headerName = (!this.reportOnly) ? CONTENT_SECURITY_POLICY_HEADER : CONTENT_SECURITY_POLICY_REPORT_ONLY_HEADER; if (!response.containsHeader(headerName)) { response.setHeader(headerName, this.policyDirectives); diff --git a/web/src/main/java/org/springframework/security/web/header/writers/HpkpHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/HpkpHeaderWriter.java index 89de450887..31ee296fde 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/HpkpHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/HpkpHeaderWriter.java @@ -174,19 +174,17 @@ public final class HpkpHeaderWriter implements HeaderWriter { @Override public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { - if (this.requestMatcher.matches(request)) { - if (!this.pins.isEmpty()) { - String headerName = this.reportOnly ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME; - if (!response.containsHeader(headerName)) { - response.setHeader(headerName, this.hpkpHeaderValue); - } - } - if (this.logger.isDebugEnabled()) { - this.logger.debug("Not injecting HPKP header since there aren't any pins"); - } - } - else if (this.logger.isDebugEnabled()) { + if (!this.requestMatcher.matches(request)) { this.logger.debug("Not injecting HPKP header since it wasn't a secure connection"); + return; + } + if (this.pins.isEmpty()) { + this.logger.debug("Not injecting HPKP header since there aren't any pins"); + return; + } + String headerName = (this.reportOnly) ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME; + if (!response.containsHeader(headerName)) { + response.setHeader(headerName, this.hpkpHeaderValue); } } @@ -294,9 +292,7 @@ public final class HpkpHeaderWriter implements HeaderWriter { * @throws IllegalArgumentException if maxAgeInSeconds is negative */ public void setMaxAgeInSeconds(long maxAgeInSeconds) { - if (maxAgeInSeconds < 0) { - throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds); - } + Assert.isTrue(maxAgeInSeconds > 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds); this.maxAgeInSeconds = maxAgeInSeconds; updateHpkpHeaderValue(); } @@ -414,11 +410,11 @@ public final class HpkpHeaderWriter implements HeaderWriter { public void setReportUri(String reportUri) { try { this.reportUri = new URI(reportUri); + updateHpkpHeaderValue(); } catch (URISyntaxException ex) { throw new IllegalArgumentException(ex); } - updateHpkpHeaderValue(); } private void updateHpkpHeaderValue() { diff --git a/web/src/main/java/org/springframework/security/web/header/writers/HstsHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/HstsHeaderWriter.java index 7d1ee6b5e5..c88bc5af87 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/HstsHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/HstsHeaderWriter.java @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -148,14 +149,13 @@ public final class HstsHeaderWriter implements HeaderWriter { @Override public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { - if (this.requestMatcher.matches(request)) { - if (!response.containsHeader(HSTS_HEADER_NAME)) { - response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue); - } + if (!this.requestMatcher.matches(request)) { + this.logger.debug(LogMessage.format( + "Not injecting HSTS header since it did not match the requestMatcher %s", this.requestMatcher)); + return; } - else if (this.logger.isDebugEnabled()) { - this.logger.debug( - "Not injecting HSTS header since it did not match the requestMatcher " + this.requestMatcher); + if (!response.containsHeader(HSTS_HEADER_NAME)) { + response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue); } } @@ -188,9 +188,7 @@ public final class HstsHeaderWriter implements HeaderWriter { * @throws IllegalArgumentException if maxAgeInSeconds is negative */ public void setMaxAgeInSeconds(long maxAgeInSeconds) { - if (maxAgeInSeconds < 0) { - throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds); - } + Assert.isTrue(maxAgeInSeconds >= 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds); this.maxAgeInSeconds = maxAgeInSeconds; updateHstsHeaderValue(); } diff --git a/web/src/main/java/org/springframework/security/web/header/writers/ReferrerPolicyHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/ReferrerPolicyHeaderWriter.java index 0397c999cf..eab56fc389 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/ReferrerPolicyHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/ReferrerPolicyHeaderWriter.java @@ -100,10 +100,21 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter { public enum ReferrerPolicy { - NO_REFERRER("no-referrer"), NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"), SAME_ORIGIN( - "same-origin"), ORIGIN("origin"), STRICT_ORIGIN("strict-origin"), ORIGIN_WHEN_CROSS_ORIGIN( - "origin-when-cross-origin"), STRICT_ORIGIN_WHEN_CROSS_ORIGIN( - "strict-origin-when-cross-origin"), UNSAFE_URL("unsafe-url"); + NO_REFERRER("no-referrer"), + + NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"), + + SAME_ORIGIN("same-origin"), + + ORIGIN("origin"), + + STRICT_ORIGIN("strict-origin"), + + ORIGIN_WHEN_CROSS_ORIGIN("origin-when-cross-origin"), + + STRICT_ORIGIN_WHEN_CROSS_ORIGIN("strict-origin-when-cross-origin"), + + UNSAFE_URL("unsafe-url"); private static final Map REFERRER_POLICIES; @@ -115,7 +126,7 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter { REFERRER_POLICIES = Collections.unmodifiableMap(referrerPolicies); } - private String policy; + private final String policy; ReferrerPolicy(String policy) { this.policy = policy; diff --git a/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/AbstractRequestParameterAllowFromStrategy.java b/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/AbstractRequestParameterAllowFromStrategy.java index 0426ccdec0..104a6cb234 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/AbstractRequestParameterAllowFromStrategy.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/AbstractRequestParameterAllowFromStrategy.java @@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -52,15 +53,11 @@ public abstract class AbstractRequestParameterAllowFromStrategy implements Allow @Override public String getAllowFromValue(HttpServletRequest request) { String allowFromOrigin = request.getParameter(this.allowFromParameterName); - if (this.log.isDebugEnabled()) { - this.log.debug("Supplied origin '" + allowFromOrigin + "'"); - } + this.log.debug(LogMessage.format("Supplied origin '%s'", allowFromOrigin)); if (StringUtils.hasText(allowFromOrigin) && allowed(allowFromOrigin)) { return allowFromOrigin; } - else { - return "DENY"; - } + return "DENY"; } /** diff --git a/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/XFrameOptionsHeaderWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/XFrameOptionsHeaderWriter.java index cf094d753b..9cecae5cc4 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/XFrameOptionsHeaderWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/frameoptions/XFrameOptionsHeaderWriter.java @@ -55,10 +55,9 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter { */ public XFrameOptionsHeaderWriter(XFrameOptionsMode frameOptionsMode) { Assert.notNull(frameOptionsMode, "frameOptionsMode cannot be null"); - if (XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode)) { - throw new IllegalArgumentException( - "ALLOW_FROM requires an AllowFromStrategy. Please use FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead"); - } + Assert.isTrue(!XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode), + "ALLOW_FROM requires an AllowFromStrategy. Please use " + + "FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead"); this.frameOptionsMode = frameOptionsMode; this.allowFromStrategy = null; } @@ -113,7 +112,10 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter { */ public enum XFrameOptionsMode { - DENY("DENY"), SAMEORIGIN("SAMEORIGIN"), + DENY("DENY"), + + SAMEORIGIN("SAMEORIGIN"), + /** * @deprecated ALLOW-FROM is an obsolete directive that no longer works in modern * browsers. Instead use Content-Security-Policy with the headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + bearerTokenValue); } - private SecurityHeaders() { - } - } diff --git a/web/src/main/java/org/springframework/security/web/jaasapi/JaasApiIntegrationFilter.java b/web/src/main/java/org/springframework/security/web/jaasapi/JaasApiIntegrationFilter.java index 49e9ff9272..bda60ba088 100644 --- a/web/src/main/java/org/springframework/security/web/jaasapi/JaasApiIntegrationFilter.java +++ b/web/src/main/java/org/springframework/security/web/jaasapi/JaasApiIntegrationFilter.java @@ -27,6 +27,7 @@ import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.jaas.JaasAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -70,34 +71,26 @@ public class JaasApiIntegrationFilter extends GenericFilterBean { *

*/ @Override - public final void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) + public final void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws ServletException, IOException { Subject subject = obtainSubject(request); if (subject == null && this.createEmptySubject) { - if (this.logger.isDebugEnabled()) { - this.logger.debug( - "Subject returned was null and createEmtpySubject is true; creating new empty subject to run as."); - } + this.logger.debug("Subject returned was null and createEmtpySubject is true; " + + "creating new empty subject to run as."); subject = new Subject(); } if (subject == null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Subject is null continue running with no Subject."); - } + this.logger.debug("Subject is null continue running with no Subject."); chain.doFilter(request, response); return; } - final PrivilegedExceptionAction continueChain = () -> { - chain.doFilter(request, response); - return null; - }; - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Running as Subject " + subject); - } + this.logger.debug(LogMessage.format("Running as Subject %s", subject)); try { - Subject.doAs(subject, continueChain); + Subject.doAs(subject, (PrivilegedExceptionAction) () -> { + chain.doFilter(request, response); + return null; + }); } catch (PrivilegedActionException ex) { throw new ServletException(ex.getMessage(), ex); @@ -121,9 +114,7 @@ public class JaasApiIntegrationFilter extends GenericFilterBean { */ protected Subject obtainSubject(ServletRequest request) { Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (this.logger.isDebugEnabled()) { - this.logger.debug("Attempting to obtainSubject using authentication : " + authentication); - } + this.logger.debug(LogMessage.format("Attempting to obtainSubject using authentication : %s", authentication)); if (authentication == null) { return null; } diff --git a/web/src/main/java/org/springframework/security/web/jackson2/PreAuthenticatedAuthenticationTokenDeserializer.java b/web/src/main/java/org/springframework/security/web/jackson2/PreAuthenticatedAuthenticationTokenDeserializer.java index b1742dcc8b..0993091f5c 100644 --- a/web/src/main/java/org/springframework/security/web/jackson2/PreAuthenticatedAuthenticationTokenDeserializer.java +++ b/web/src/main/java/org/springframework/security/web/jackson2/PreAuthenticatedAuthenticationTokenDeserializer.java @@ -46,6 +46,9 @@ import org.springframework.security.web.authentication.preauth.PreAuthenticatedA */ class PreAuthenticatedAuthenticationTokenDeserializer extends JsonDeserializer { + private static final TypeReference> GRANTED_AUTHORITY_LIST = new TypeReference>() { + }; + /** * This method construct {@link PreAuthenticatedAuthenticationToken} object from * serialized json. @@ -58,28 +61,18 @@ class PreAuthenticatedAuthenticationTokenDeserializer extends JsonDeserializer

authorities = mapper.readValue(readJsonNode(jsonNode, "authorities").traverse(mapper), - new TypeReference>() { - }); - if (authenticated) { - token = new PreAuthenticatedAuthenticationToken(principal, credentials, authorities); - } - else { - token = new PreAuthenticatedAuthenticationToken(principal, credentials); - } + GRANTED_AUTHORITY_LIST); + PreAuthenticatedAuthenticationToken token = (!authenticated) + ? new PreAuthenticatedAuthenticationToken(principal, credentials) + : new PreAuthenticatedAuthenticationToken(principal, credentials, authorities); token.setDetails(readJsonNode(jsonNode, "details")); return token; } diff --git a/web/src/main/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolver.java b/web/src/main/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolver.java index eb455c75cf..c54a81d3b8 100644 --- a/web/src/main/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolver.java @@ -104,28 +104,21 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet return null; } Object principal = authentication.getPrincipal(); - - AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); - - String expressionToParse = authPrincipal.expression(); + AuthenticationPrincipal annotation = findMethodAnnotation(AuthenticationPrincipal.class, parameter); + String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) { - - if (authPrincipal.errorOnInvalidType()) { + if (annotation.errorOnInvalidType()) { throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } return principal; } diff --git a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java index 007ddc7ceb..6cbdb0c2d2 100644 --- a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -79,17 +79,11 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth private BeanResolver beanResolver; - /** - * {@inheritDoc} - */ @Override public boolean supportsParameter(MethodParameter parameter) { return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; } - /** - * {@inheritDoc} - */ @Override public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, NativeWebRequest webRequest, WebDataBinderFactory binderFactory) { @@ -98,29 +92,22 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth return null; } Object securityContextResult = securityContext; - - CurrentSecurityContext securityContextAnnotation = findMethodAnnotation(CurrentSecurityContext.class, - parameter); - - String expressionToParse = securityContextAnnotation.expression(); + CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); + String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(securityContext); context.setVariable("this", securityContext); - Expression expression = this.parser.parseExpression(expressionToParse); securityContextResult = expression.getValue(context); } - if (securityContextResult != null && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { - if (securityContextAnnotation.errorOnInvalidType()) { + if (annotation.errorOnInvalidType()) { throw new ClassCastException( securityContextResult + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } return securityContextResult; } diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolver.java b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolver.java index 6068d82f9d..5be69521ab 100644 --- a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolver.java @@ -72,37 +72,31 @@ public class AuthenticationPrincipalArgumentResolver extends HandlerMethodArgume public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) { ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType()); - return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication).flatMap((a) -> { - Object p = resolvePrincipal(parameter, a.getPrincipal()); - Mono principal = Mono.justOrEmpty(p); - return (adapter != null) ? Mono.just(adapter.fromPublisher(principal)) : principal; - }); + return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) + .flatMap((authentication) -> { + Mono principal = Mono + .justOrEmpty(resolvePrincipal(parameter, authentication.getPrincipal())); + return (adapter != null) ? Mono.just(adapter.fromPublisher(principal)) : principal; + }); } private Object resolvePrincipal(MethodParameter parameter, Object principal) { - AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); - - String expressionToParse = authPrincipal.expression(); + AuthenticationPrincipal annotation = findMethodAnnotation(AuthenticationPrincipal.class, parameter); + String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - if (isInvalidType(parameter, principal)) { - - if (authPrincipal.errorOnInvalidType()) { + if (annotation.errorOnInvalidType()) { throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } - return principal; } diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java index 101749642f..a02a9c30b9 100644 --- a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -65,17 +65,11 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen this.beanResolver = beanResolver; } - /** - * {@inheritDoc} - */ @Override public boolean supportsParameter(MethodParameter parameter) { return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; } - /** - * {@inheritDoc} - */ @Override public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) { @@ -84,10 +78,10 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen if (reactiveSecurityContext == null) { return null; } - return reactiveSecurityContext.flatMap((a) -> { - Object p = resolveSecurityContext(parameter, a); - Mono o = Mono.justOrEmpty(p); - return (adapter != null) ? Mono.just(adapter.fromPublisher(o)) : o; + return reactiveSecurityContext.flatMap((securityContext) -> { + Mono resolvedSecurityContext = Mono.justOrEmpty(resolveSecurityContext(parameter, securityContext)); + return (adapter != null) ? Mono.just(adapter.fromPublisher(resolvedSecurityContext)) + : resolvedSecurityContext; }); } @@ -100,32 +94,24 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen * @return the resolved object from expression. */ private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) { - CurrentSecurityContext securityContextAnnotation = findMethodAnnotation(CurrentSecurityContext.class, - parameter); - + CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); Object securityContextResult = securityContext; - - String expressionToParse = securityContextAnnotation.expression(); + String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(securityContext); context.setVariable("this", securityContext); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); securityContextResult = expression.getValue(context); } - if (isInvalidType(parameter, securityContextResult)) { - if (securityContextAnnotation.errorOnInvalidType()) { + if (annotation.errorOnInvalidType()) { throw new ClassCastException( securityContextResult + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } - return securityContextResult; } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java index 3117a37016..1afefeba74 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java @@ -55,52 +55,49 @@ public class CookieRequestCache implements RequestCache { @Override public void saveRequest(HttpServletRequest request, HttpServletResponse response) { - if (this.requestMatcher.matches(request)) { - String redirectUrl = UrlUtils.buildFullRequestUrl(request); - Cookie savedCookie = new Cookie(COOKIE_NAME, encodeCookie(redirectUrl)); - savedCookie.setMaxAge(COOKIE_MAX_AGE); - savedCookie.setSecure(request.isSecure()); - savedCookie.setPath(getCookiePath(request)); - savedCookie.setHttpOnly(true); - - response.addCookie(savedCookie); - } - else { + if (!this.requestMatcher.matches(request)) { this.logger.debug("Request not saved as configured RequestMatcher did not match"); + return; } + String redirectUrl = UrlUtils.buildFullRequestUrl(request); + Cookie savedCookie = new Cookie(COOKIE_NAME, encodeCookie(redirectUrl)); + savedCookie.setMaxAge(COOKIE_MAX_AGE); + savedCookie.setSecure(request.isSecure()); + savedCookie.setPath(getCookiePath(request)); + savedCookie.setHttpOnly(true); + response.addCookie(savedCookie); } @Override public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response) { Cookie savedRequestCookie = WebUtils.getCookie(request, COOKIE_NAME); - if (savedRequestCookie != null) { - final String originalURI = decodeCookie(savedRequestCookie.getValue()); - UriComponents uriComponents = UriComponentsBuilder.fromUriString(originalURI).build(); - DefaultSavedRequest.Builder builder = new DefaultSavedRequest.Builder(); - - int port = uriComponents.getPort(); - if (port == -1) { - if ("https".equalsIgnoreCase(uriComponents.getScheme())) { - port = 443; - } - else { - port = 80; - } - } - - final MultiValueMap queryParams = uriComponents.getQueryParams(); - - if (!queryParams.isEmpty()) { - final HashMap parameters = new HashMap<>(queryParams.size()); - queryParams.forEach((key, value) -> parameters.put(key, value.toArray(new String[] {}))); - builder.setParameters(parameters); - } - - return builder.setScheme(uriComponents.getScheme()).setServerName(uriComponents.getHost()) - .setRequestURI(uriComponents.getPath()).setQueryString(uriComponents.getQuery()).setServerPort(port) - .setMethod(request.getMethod()).build(); + if (savedRequestCookie == null) { + return null; } - return null; + String originalURI = decodeCookie(savedRequestCookie.getValue()); + UriComponents uriComponents = UriComponentsBuilder.fromUriString(originalURI).build(); + DefaultSavedRequest.Builder builder = new DefaultSavedRequest.Builder(); + int port = getPort(uriComponents); + MultiValueMap queryParams = uriComponents.getQueryParams(); + if (!queryParams.isEmpty()) { + HashMap parameters = new HashMap<>(queryParams.size()); + queryParams.forEach((key, value) -> parameters.put(key, value.toArray(new String[] {}))); + builder.setParameters(parameters); + } + return builder.setScheme(uriComponents.getScheme()).setServerName(uriComponents.getHost()) + .setRequestURI(uriComponents.getPath()).setQueryString(uriComponents.getQuery()).setServerPort(port) + .setMethod(request.getMethod()).build(); + } + + private int getPort(UriComponents uriComponents) { + int port = uriComponents.getPort(); + if (port != -1) { + return port; + } + if ("https".equalsIgnoreCase(uriComponents.getScheme())) { + return 443; + } + return 80; } @Override @@ -110,10 +107,8 @@ public class CookieRequestCache implements RequestCache { this.logger.debug("saved request doesn't match"); return null; } - else { - this.removeRequest(request, response); - return new SavedRequestAwareWrapper(saved, request); - } + this.removeRequest(request, response); + return new SavedRequestAwareWrapper(saved, request); } @Override @@ -135,21 +130,16 @@ public class CookieRequestCache implements RequestCache { } private static String getCookiePath(HttpServletRequest request) { - final String contextPath = request.getContextPath(); - if (StringUtils.isEmpty(contextPath)) { - return "/"; - } - return contextPath; + String contextPath = request.getContextPath(); + return (!StringUtils.isEmpty(contextPath)) ? contextPath : "/"; } private boolean matchesSavedRequest(HttpServletRequest request, SavedRequest savedRequest) { if (savedRequest == null) { return false; } - else { - String currentUrl = UrlUtils.buildFullRequestUrl(request); - return savedRequest.getRedirectUrl().equals(currentUrl); - } + String currentUrl = UrlUtils.buildFullRequestUrl(request); + return savedRequest.getRedirectUrl().equals(currentUrl); } /** diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java index 5b8b1aacd9..aaa8e995c2 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java @@ -33,6 +33,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.PortResolver; import org.springframework.security.web.util.UrlUtils; import org.springframework.util.Assert; @@ -99,13 +100,10 @@ public class DefaultSavedRequest implements SavedRequest { public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) { Assert.notNull(request, "Request required"); Assert.notNull(portResolver, "PortResolver required"); - // Cookies addCookies(request.getCookies()); - // Headers Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { String name = names.nextElement(); // Skip If-Modified-Since and If-None-Match header. SEC-1412, SEC-1624. @@ -113,18 +111,14 @@ public class DefaultSavedRequest implements SavedRequest { continue; } Enumeration values = request.getHeaders(name); - while (values.hasMoreElements()) { this.addHeader(name, values.nextElement()); } } - // Locales addLocales(request.getLocales()); - // Parameters addParameters(request.getParameterMap()); - // Primitives this.method = request.getMethod(); this.pathInfo = request.getPathInfo(); @@ -170,8 +164,7 @@ public class DefaultSavedRequest implements SavedRequest { } private void addHeader(String name, String value) { - List values = this.headers.computeIfAbsent(name, (k) -> new ArrayList<>()); - + List values = this.headers.computeIfAbsent(name, (key) -> new ArrayList<>()); values.add(value); } @@ -200,9 +193,7 @@ public class DefaultSavedRequest implements SavedRequest { this.addParameter(paramName, (String[]) paramValues); } else { - if (logger.isWarnEnabled()) { - logger.warn("ServletRequest.getParameterMap() returned non-String array"); - } + logger.warn("ServletRequest.getParameterMap() returned non-String array"); } } } @@ -221,44 +212,34 @@ public class DefaultSavedRequest implements SavedRequest { * @return true if the request is deemed to match this one. */ public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) { - if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) { return false; } - if (!propertyEquals("queryString", this.queryString, request.getQueryString())) { return false; } - if (!propertyEquals("requestURI", this.requestURI, request.getRequestURI())) { return false; } - if (!"GET".equals(request.getMethod()) && "GET".equals(this.method)) { // A save GET should not match an incoming non-GET method return false; } - if (!propertyEquals("serverPort", this.serverPort, portResolver.getServerPort(request))) { return false; } - if (!propertyEquals("requestURL", this.requestURL, request.getRequestURL().toString())) { return false; } - if (!propertyEquals("scheme", this.scheme, request.getScheme())) { return false; } - if (!propertyEquals("serverName", this.serverName, request.getServerName())) { return false; } - if (!propertyEquals("contextPath", this.contextPath, request.getContextPath())) { return false; } - return propertyEquals("servletPath", this.servletPath, request.getServletPath()); } @@ -270,11 +251,9 @@ public class DefaultSavedRequest implements SavedRequest { @Override public List getCookies() { List cookieList = new ArrayList<>(this.cookies.size()); - for (SavedCookie savedCookie : this.cookies) { cookieList.add(savedCookie.getCookie()); } - return cookieList; } @@ -296,12 +275,7 @@ public class DefaultSavedRequest implements SavedRequest { @Override public List getHeaderValues(String name) { List values = this.headers.get(name); - - if (values == null) { - return Collections.emptyList(); - } - - return values; + return (values != null) ? values : Collections.emptyList(); } @Override @@ -362,35 +336,19 @@ public class DefaultSavedRequest implements SavedRequest { private boolean propertyEquals(String log, Object arg1, Object arg2) { if ((arg1 == null) && (arg2 == null)) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": both null (property equals)"); - } - + logger.debug(LogMessage.format("%s: both null (property equals)", log)); return true; } - if (arg1 == null || arg2 == null) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); - } - + logger.debug(LogMessage.format("%s: arg1=%s; arg2=%s (property not equals)", log, arg1, arg2)); return false; } - if (arg1.equals(arg2)) { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property equals)"); - } - + logger.debug(LogMessage.format("%s: arg1=%s; arg2=%s (property equals)", log, arg1, arg2)); return true; } - else { - if (logger.isDebugEnabled()) { - logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + " (property not equals)"); - } - - return false; - } + logger.debug(LogMessage.format("%s: arg1=%s; arg2=%s (property not equals)", log, arg1, arg2)); + return false; } @Override @@ -514,7 +472,6 @@ public class DefaultSavedRequest implements SavedRequest { savedRequest.locales.addAll(this.locales); } savedRequest.addParameters(this.parameters); - this.headers.remove(HEADER_IF_MODIFIED_SINCE); this.headers.remove(HEADER_IF_NONE_MATCH); for (Map.Entry> entry : this.headers.entrySet()) { diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/Enumerator.java b/web/src/main/java/org/springframework/security/web/savedrequest/Enumerator.java index 45e7e69983..fdcb553b55 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/Enumerator.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/Enumerator.java @@ -78,17 +78,14 @@ public class Enumerator implements Enumeration { * @param clone true to clone iterator */ public Enumerator(Iterator iterator, boolean clone) { - if (!clone) { this.iterator = iterator; } else { List list = new ArrayList<>(); - while (iterator.hasNext()) { list.add(iterator.next()); } - this.iterator = list.iterator(); } } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/FastHttpDateFormat.java b/web/src/main/java/org/springframework/security/web/savedrequest/FastHttpDateFormat.java index 58bdb08cfd..dd11a3c5b0 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/FastHttpDateFormat.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/FastHttpDateFormat.java @@ -34,36 +34,49 @@ import java.util.TimeZone; */ public final class FastHttpDateFormat { - /** HTTP date format. */ + /** + * HTTP date format. + */ protected static final SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); - /** The set of SimpleDateFormat formats to use in getDateHeader(). */ + /** + * The set of SimpleDateFormat formats to use in getDateHeader(). + */ protected static final SimpleDateFormat[] formats = { new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US), new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US), new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", Locale.US) }; - /** GMT time zone - all HTTP dates are on GMT */ + /** + * GMT time zone - all HTTP dates are on GMT + */ protected static final TimeZone gmtZone = TimeZone.getTimeZone("GMT"); static { format.setTimeZone(gmtZone); - formats[0].setTimeZone(gmtZone); formats[1].setTimeZone(gmtZone); formats[2].setTimeZone(gmtZone); } - /** Instant on which the currentDate object was generated. */ + /** + * Instant on which the currentDate object was generated. + */ protected static long currentDateGenerated = 0L; - /** Current formatted date. */ + /** + * Current formatted date. + */ protected static String currentDate = null; - /** Formatter cache. */ + /** + * Formatter cache. + */ protected static final HashMap formatCache = new HashMap<>(); - /** Parser cache. */ + /** + * Parser cache. + */ protected static final HashMap parseCache = new HashMap<>(); private FastHttpDateFormat() { @@ -80,23 +93,18 @@ public final class FastHttpDateFormat { public static String formatDate(long value, DateFormat threadLocalformat) { String cachedDate = null; Long longValue = value; - try { cachedDate = formatCache.get(longValue); } - catch (Exception ignored) { + catch (Exception ex) { } - if (cachedDate != null) { return cachedDate; } - String newDate; Date dateValue = new Date(value); - if (threadLocalformat != null) { newDate = threadLocalformat.format(dateValue); - synchronized (formatCache) { updateCache(formatCache, longValue, newDate); } @@ -107,7 +115,6 @@ public final class FastHttpDateFormat { updateCache(formatCache, longValue, newDate); } } - return newDate; } @@ -117,7 +124,6 @@ public final class FastHttpDateFormat { */ public static String getCurrentDate() { long now = System.currentTimeMillis(); - if ((now - currentDateGenerated) > 1000) { synchronized (format) { if ((now - currentDateGenerated) > 1000) { @@ -126,7 +132,6 @@ public final class FastHttpDateFormat { } } } - return currentDate; } @@ -138,19 +143,16 @@ public final class FastHttpDateFormat { */ private static Long internalParseDate(String value, DateFormat[] formats) { Date date = null; - for (int i = 0; (date == null) && (i < formats.length); i++) { try { date = formats[i].parse(value); } - catch (ParseException ignored) { + catch (ParseException ex) { } } - if (date == null) { return null; } - return date.getTime(); } @@ -164,22 +166,17 @@ public final class FastHttpDateFormat { */ public static long parseDate(String value, DateFormat[] threadLocalformats) { Long cachedDate = null; - try { cachedDate = parseCache.get(value); } - catch (Exception ignored) { + catch (Exception ex) { } - if (cachedDate != null) { return cachedDate; } - Long date; - if (threadLocalformats != null) { date = internalParseDate(value, threadLocalformats); - synchronized (parseCache) { updateCache(parseCache, value, date); } @@ -190,13 +187,7 @@ public final class FastHttpDateFormat { updateCache(parseCache, value, date); } } - - if (date == null) { - return (-1L); - } - else { - return date; - } + return (date != null) ? date : -1L; } /** @@ -210,11 +201,9 @@ public final class FastHttpDateFormat { if (value == null) { return; } - if (cache.size() > 1000) { cache.clear(); } - cache.put(key, value); } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java index ba261fb0e7..a3ac70e4c1 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java @@ -23,6 +23,7 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.PortResolver; import org.springframework.security.web.PortResolverImpl; import org.springframework.security.web.util.UrlUtils; @@ -57,37 +58,29 @@ public class HttpSessionRequestCache implements RequestCache { */ @Override public void saveRequest(HttpServletRequest request, HttpServletResponse response) { - if (this.requestMatcher.matches(request)) { - DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver); - - if (this.createSessionAllowed || request.getSession(false) != null) { - // Store the HTTP request itself. Used by - // AbstractAuthenticationProcessingFilter - // for redirection after successful authentication (SEC-29) - request.getSession().setAttribute(this.sessionAttrName, savedRequest); - this.logger.debug("DefaultSavedRequest added to Session: " + savedRequest); - } - } - else { + if (!this.requestMatcher.matches(request)) { this.logger.debug("Request not saved as configured RequestMatcher did not match"); + return; + } + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver); + if (this.createSessionAllowed || request.getSession(false) != null) { + // Store the HTTP request itself. Used by + // AbstractAuthenticationProcessingFilter + // for redirection after successful authentication (SEC-29) + request.getSession().setAttribute(this.sessionAttrName, savedRequest); + this.logger.debug(LogMessage.format("DefaultSavedRequest added to Session: %s", savedRequest)); } } @Override public SavedRequest getRequest(HttpServletRequest currentRequest, HttpServletResponse response) { HttpSession session = currentRequest.getSession(false); - - if (session != null) { - return (SavedRequest) session.getAttribute(this.sessionAttrName); - } - - return null; + return (session != null) ? (SavedRequest) session.getAttribute(this.sessionAttrName) : null; } @Override public void removeRequest(HttpServletRequest currentRequest, HttpServletResponse response) { HttpSession session = currentRequest.getSession(false); - if (session != null) { this.logger.debug("Removing DefaultSavedRequest from session if present"); session.removeAttribute(this.sessionAttrName); @@ -97,14 +90,11 @@ public class HttpSessionRequestCache implements RequestCache { @Override public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) { SavedRequest saved = getRequest(request, response); - if (!matchesSavedRequest(request, saved)) { this.logger.debug("saved request doesn't match"); return null; } - removeRequest(request, response); - return new SavedRequestAwareWrapper(saved, request); } @@ -112,12 +102,10 @@ public class HttpSessionRequestCache implements RequestCache { if (savedRequest == null) { return false; } - if (savedRequest instanceof DefaultSavedRequest) { DefaultSavedRequest defaultSavedRequest = (DefaultSavedRequest) savedRequest; return defaultSavedRequest.doesRequestMatch(request, this.portResolver); } - String currentUrl = UrlUtils.buildFullRequestUrl(request); return savedRequest.getRedirectUrl().equals(currentUrl); } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilter.java b/web/src/main/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilter.java index 27657614c0..74b769c1ee 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilter.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/RequestCacheAwareFilter.java @@ -58,10 +58,8 @@ public class RequestCacheAwareFilter extends GenericFilterBean { @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest wrappedSavedRequest = this.requestCache.getMatchingRequest((HttpServletRequest) request, (HttpServletResponse) response); - chain.doFilter((wrappedSavedRequest != null) ? wrappedSavedRequest : request, response); } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedCookie.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedCookie.java index 79f6651e74..9357e98fbe 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedCookie.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedCookie.java @@ -93,24 +93,20 @@ public class SavedCookie implements Serializable { } public Cookie getCookie() { - Cookie c = new Cookie(getName(), getValue()); - + Cookie cookie = new Cookie(getName(), getValue()); if (getComment() != null) { - c.setComment(getComment()); + cookie.setComment(getComment()); } - if (getDomain() != null) { - c.setDomain(getDomain()); + cookie.setDomain(getDomain()); } - if (getPath() != null) { - c.setPath(getPath()); + cookie.setPath(getPath()); } - - c.setVersion(getVersion()); - c.setMaxAge(getMaxAge()); - c.setSecure(isSecure()); - return c; + cookie.setVersion(getVersion()); + cookie.setMaxAge(getMaxAge()); + cookie.setSecure(isSecure()); + return cookie; } } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java index adee0172e6..dfd4726675 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java @@ -73,11 +73,9 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) { super(request); this.savedRequest = saved; - this.formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); this.formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US); this.formats[2] = new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", Locale.US); - this.formats[0].setTimeZone(GMT_ZONE); this.formats[1].setTimeZone(GMT_ZONE); this.formats[2].setTimeZone(GMT_ZONE); @@ -86,25 +84,20 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public long getDateHeader(String name) { String value = getHeader(name); - if (value == null) { return -1L; } - // Attempt to convert the date header in a variety of formats long result = FastHttpDateFormat.parseDate(value, this.formats); - if (result != -1L) { return result; } - throw new IllegalArgumentException(value); } @Override public String getHeader(String name) { List values = this.savedRequest.getHeaderValues(name); - return values.isEmpty() ? null : values.get(0); } @@ -123,19 +116,12 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public int getIntHeader(String name) { String value = getHeader(name); - - if (value == null) { - return -1; - } - else { - return Integer.parseInt(value); - } + return (value != null) ? Integer.parseInt(value) : -1; } @Override public Locale getLocale() { List locales = this.savedRequest.getLocales(); - return locales.isEmpty() ? Locale.getDefault() : locales.get(0); } @@ -143,13 +129,11 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @SuppressWarnings("unchecked") public Enumeration getLocales() { List locales = this.savedRequest.getLocales(); - if (locales.isEmpty()) { // Fall back to default locale locales = new ArrayList<>(1); locales.add(Locale.getDefault()); } - return new Enumerator<>(locales); } @@ -171,17 +155,13 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { @Override public String getParameter(String name) { String value = super.getParameter(name); - if (value != null) { return value; } - String[] values = this.savedRequest.getParameterValues(name); - if (values == null || values.length == 0) { return null; } - return values[0]; } @@ -190,11 +170,9 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { public Map getParameterMap() { Set names = getCombinedParameterNames(); Map parameterMap = new HashMap<>(names.size()); - for (String name : names) { parameterMap.put(name, getParameterValues(name)); } - return parameterMap; } @@ -203,7 +181,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { Set names = new HashSet<>(); names.addAll(super.getParameterMap().keySet()); names.addAll(this.savedRequest.getParameterMap().keySet()); - return names; } @@ -217,19 +194,15 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { public String[] getParameterValues(String name) { String[] savedRequestParams = this.savedRequest.getParameterValues(name); String[] wrappedRequestParams = super.getParameterValues(name); - if (savedRequestParams == null) { return wrappedRequestParams; } - if (wrappedRequestParams == null) { return savedRequestParams; } - // We have parameters in both saved and wrapped requests so have to merge them List wrappedParamsList = Arrays.asList(wrappedRequestParams); List combinedParams = new ArrayList<>(wrappedParamsList); - // We want to add all parameters of the saved request *apart from* duplicates of // those already added for (String savedRequestParam : savedRequestParams) { @@ -237,7 +210,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper { combinedParams.add(savedRequestParam); } } - return combinedParams.toArray(new String[0]); } diff --git a/web/src/main/java/org/springframework/security/web/server/DefaultServerRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/server/DefaultServerRedirectStrategy.java index 36bf02a42b..4244880daf 100644 --- a/web/src/main/java/org/springframework/security/web/server/DefaultServerRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/server/DefaultServerRedirectStrategy.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; @@ -50,9 +51,7 @@ public class DefaultServerRedirectStrategy implements ServerRedirectStrategy { ServerHttpResponse response = exchange.getResponse(); response.setStatusCode(this.httpStatus); URI newLocation = createLocation(exchange, location); - if (logger.isDebugEnabled()) { - logger.debug("Redirecting to '" + newLocation + "'"); - } + logger.debug(LogMessage.format("Redirecting to '%s'", newLocation)); response.getHeaders().setLocation(newLocation); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/DelegatingServerAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/server/DelegatingServerAuthenticationEntryPoint.java index 633a996302..1a3ca56676 100644 --- a/web/src/main/java/org/springframework/security/web/server/DelegatingServerAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/server/DelegatingServerAuthenticationEntryPoint.java @@ -24,9 +24,11 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -44,7 +46,7 @@ public class DelegatingServerAuthenticationEntryPoint implements ServerAuthentic private final List entryPoints; - private ServerAuthenticationEntryPoint defaultEntryPoint = (exchange, e) -> { + private ServerAuthenticationEntryPoint defaultEntryPoint = (exchange, ex) -> { exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED); return exchange.getResponse().setComplete(); }; @@ -61,23 +63,18 @@ public class DelegatingServerAuthenticationEntryPoint implements ServerAuthentic @Override public Mono commence(ServerWebExchange exchange, AuthenticationException ex) { return Flux.fromIterable(this.entryPoints).filterWhen((entry) -> isMatch(exchange, entry)).next() - .map((entry) -> entry.getEntryPoint()).doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("Match found! Executing " + it); - } - }).switchIfEmpty(Mono.just(this.defaultEntryPoint).doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("No match found. Using default entry point " + this.defaultEntryPoint); - } - })).flatMap((entryPoint) -> entryPoint.commence(exchange, ex)); + .map((entry) -> entry.getEntryPoint()) + .doOnNext((entryPoint) -> logger.debug(LogMessage.format("Match found! Executing %s", entryPoint))) + .switchIfEmpty(Mono.just(this.defaultEntryPoint) + .doOnNext((entryPoint) -> logger.debug(LogMessage + .format("No match found. Using default entry point %s", this.defaultEntryPoint)))) + .flatMap((entryPoint) -> entryPoint.commence(exchange, ex)); } private Mono isMatch(ServerWebExchange exchange, DelegateEntry entry) { ServerWebExchangeMatcher matcher = entry.getMatcher(); - if (logger.isDebugEnabled()) { - logger.debug("Trying to match using " + matcher); - } - return matcher.matches(exchange).map((result) -> result.isMatch()); + logger.debug(LogMessage.format("Trying to match using %s", matcher)); + return matcher.matches(exchange).map(MatchResult::isMatch); } /** diff --git a/web/src/main/java/org/springframework/security/web/server/ServerHttpBasicAuthenticationConverter.java b/web/src/main/java/org/springframework/security/web/server/ServerHttpBasicAuthenticationConverter.java index e0afa30196..3f58b31ec7 100644 --- a/web/src/main/java/org/springframework/security/web/server/ServerHttpBasicAuthenticationConverter.java +++ b/web/src/main/java/org/springframework/security/web/server/ServerHttpBasicAuthenticationConverter.java @@ -47,26 +47,18 @@ public class ServerHttpBasicAuthenticationConverter implements Function apply(ServerWebExchange exchange) { ServerHttpRequest request = exchange.getRequest(); - String authorization = request.getHeaders().getFirst(HttpHeaders.AUTHORIZATION); if (!StringUtils.startsWithIgnoreCase(authorization, "basic ")) { return Mono.empty(); } - String credentials = (authorization.length() <= BASIC.length()) ? "" : authorization.substring(BASIC.length(), authorization.length()); - byte[] decodedCredentials = base64Decode(credentials); - String decodedAuthz = new String(decodedCredentials); - String[] userParts = decodedAuthz.split(":", 2); - - if (userParts.length != 2) { + String decoded = new String(base64Decode(credentials)); + String[] parts = decoded.split(":", 2); + if (parts.length != 2) { return Mono.empty(); } - - String username = userParts[0]; - String password = userParts[1]; - - return Mono.just(new UsernamePasswordAuthenticationToken(username, password)); + return Mono.just(new UsernamePasswordAuthenticationToken(parts[0], parts[1])); } private byte[] base64Decode(String value) { diff --git a/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java b/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java index 0e03b36497..f4654ab25f 100644 --- a/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java @@ -52,8 +52,7 @@ public class WebFilterChainProxy implements WebFilter { .filterWhen((securityWebFilterChain) -> securityWebFilterChain.matches(exchange)).next() .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .flatMap((securityWebFilterChain) -> securityWebFilterChain.getWebFilters().collectList()) - .map((filters) -> new FilteringWebHandler((webHandler) -> chain.filter(webHandler), filters)) - .map((handler) -> new DefaultWebFilterChain(handler)) + .map((filters) -> new FilteringWebHandler(chain::filter, filters)).map(DefaultWebFilterChain::new) .flatMap((securedChain) -> securedChain.filter(exchange)); } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AnonymousAuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AnonymousAuthenticationWebFilter.java index 131cf14796..9f212b2c19 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AnonymousAuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AnonymousAuthenticationWebFilter.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -80,24 +81,19 @@ public class AnonymousAuthenticationWebFilter implements WebFilter { return ReactiveSecurityContextHolder.getContext().switchIfEmpty(Mono.defer(() -> { Authentication authentication = createAuthentication(exchange); SecurityContext securityContext = new SecurityContextImpl(authentication); - if (logger.isDebugEnabled()) { - logger.debug("Populated SecurityContext with anonymous token: '" + authentication + "'"); - } + logger.debug(LogMessage.format("Populated SecurityContext with anonymous token: '%s'", authentication)); return chain.filter(exchange) .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) .then(Mono.empty()); })).flatMap((securityContext) -> { - if (logger.isDebugEnabled()) { - logger.debug("SecurityContext contains anonymous token: '" + securityContext.getAuthentication() + "'"); - } + logger.debug(LogMessage.format("SecurityContext contains anonymous token: '%s'", + securityContext.getAuthentication())); return chain.filter(exchange); }); } protected Authentication createAuthentication(ServerWebExchange exchange) { - AnonymousAuthenticationToken auth = new AnonymousAuthenticationToken(this.key, this.principal, - this.authorities); - return auth; + return new AnonymousAuthenticationToken(this.key, this.principal, this.authorities); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationConverterServerWebExchangeMatcher.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationConverterServerWebExchangeMatcher.java index 8b60a584e7..adea58aef4 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationConverterServerWebExchangeMatcher.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationConverterServerWebExchangeMatcher.java @@ -44,7 +44,7 @@ public final class AuthenticationConverterServerWebExchangeMatcher implements Se @Override public Mono matches(ServerWebExchange exchange) { return this.serverAuthenticationConverter.convert(exchange).flatMap((a) -> MatchResult.match()) - .onErrorResume((e) -> MatchResult.notMatch()).switchIfEmpty(MatchResult.notMatch()); + .onErrorResume((ex) -> MatchResult.notMatch()).switchIfEmpty(MatchResult.notMatch()); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index b7d51de9d2..45178bb078 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver; import org.springframework.security.core.Authentication; @@ -111,8 +112,8 @@ public class AuthenticationWebFilter implements WebFilter { .flatMap((matchResult) -> this.authenticationConverter.convert(exchange)) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .flatMap((token) -> authenticate(exchange, chain, token)) - .onErrorResume(AuthenticationException.class, (e) -> this.authenticationFailureHandler - .onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); + .onErrorResume(AuthenticationException.class, (ex) -> this.authenticationFailureHandler + .onAuthenticationFailure(new WebFilterExchange(exchange, chain), ex)); } private Mono authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { @@ -122,11 +123,8 @@ public class AuthenticationWebFilter implements WebFilter { () -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) .flatMap((authentication) -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain))) - .doOnError(AuthenticationException.class, (e) -> { - if (logger.isDebugEnabled()) { - logger.debug("Authentication failed: " + e.getMessage()); - } - }); + .doOnError(AuthenticationException.class, + (ex) -> logger.debug(LogMessage.format("Authentication failed: %s", ex.getMessage()))); } protected Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ReactivePreAuthenticatedAuthenticationManager.java b/web/src/main/java/org/springframework/security/web/server/authentication/ReactivePreAuthenticatedAuthenticationManager.java index d21ff1363f..f60ee37a6e 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ReactivePreAuthenticatedAuthenticationManager.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ReactivePreAuthenticatedAuthenticationManager.java @@ -61,11 +61,10 @@ public class ReactivePreAuthenticatedAuthenticationManager implements ReactiveAu return Mono.just(authentication).filter(this::supports).map(Authentication::getName) .flatMap(this.userDetailsService::findByUsername) .switchIfEmpty(Mono.error(() -> new UsernameNotFoundException("User not found"))) - .doOnNext(this.userDetailsChecker::check).map((ud) -> { - PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(ud, - authentication.getCredentials(), ud.getAuthorities()); + .doOnNext(this.userDetailsChecker::check).map((userDetails) -> { + PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(userDetails, + authentication.getCredentials(), userDetails.getAuthorities()); result.setDetails(authentication.getDetails()); - return result; }); } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverter.java b/web/src/main/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverter.java index f7b9caae19..f94ebe91bc 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ServerX509AuthenticationConverter.java @@ -50,28 +50,16 @@ public class ServerX509AuthenticationConverter implements ServerAuthenticationCo public Mono convert(ServerWebExchange exchange) { SslInfo sslInfo = exchange.getRequest().getSslInfo(); if (sslInfo == null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("No SslInfo provided with a request, skipping x509 authentication"); - } - + this.logger.debug("No SslInfo provided with a request, skipping x509 authentication"); return Mono.empty(); } - if (sslInfo.getPeerCertificates() == null || sslInfo.getPeerCertificates().length == 0) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("No peer certificates found in SslInfo, skipping x509 authentication"); - } - + this.logger.debug("No peer certificates found in SslInfo, skipping x509 authentication"); return Mono.empty(); } - X509Certificate clientCertificate = sslInfo.getPeerCertificates()[0]; Object principal = this.principalExtractor.extractPrincipal(clientCertificate); - - PreAuthenticatedAuthenticationToken authRequest = new PreAuthenticatedAuthenticationToken(principal, - clientCertificate); - - return Mono.just(authRequest); + return Mono.just(new PreAuthenticatedAuthenticationToken(principal, clientCertificate)); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/SwitchUserWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/SwitchUserWebFilter.java index 38f59f48f2..410da7cb6c 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/SwitchUserWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/SwitchUserWebFilter.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; @@ -125,11 +126,9 @@ public class SwitchUserWebFilter implements WebFilter { @Nullable ServerAuthenticationFailureHandler failureHandler) { Assert.notNull(userDetailsService, "userDetailsService must be specified"); Assert.notNull(successHandler, "successHandler must be specified"); - this.userDetailsService = userDetailsService; this.successHandler = successHandler; this.failureHandler = failureHandler; - this.securityContextRepository = new WebSessionServerSecurityContextRepository(); this.userDetailsChecker = new AccountStatusUserDetailsChecker(); } @@ -147,17 +146,10 @@ public class SwitchUserWebFilter implements WebFilter { @Nullable String failureTargetUrl) { Assert.notNull(userDetailsService, "userDetailsService must be specified"); Assert.notNull(successTargetUrl, "successTargetUrl must be specified"); - this.userDetailsService = userDetailsService; this.successHandler = new RedirectServerAuthenticationSuccessHandler(successTargetUrl); - - if (failureTargetUrl != null) { - this.failureHandler = new RedirectServerAuthenticationFailureHandler(failureTargetUrl); - } - else { - this.failureHandler = null; - } - + this.failureHandler = (failureTargetUrl != null) + ? new RedirectServerAuthenticationFailureHandler(failureTargetUrl) : null; this.securityContextRepository = new WebSessionServerSecurityContextRepository(); this.userDetailsChecker = new AccountStatusUserDetailsChecker(); } @@ -165,7 +157,6 @@ public class SwitchUserWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { final WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); - return switchUser(webFilterExchange).switchIfEmpty(Mono.defer(() -> exitSwitchUser(webFilterExchange))) .switchIfEmpty(Mono.defer(() -> chain.filter(exchange).then(Mono.empty()))) .flatMap((authentication) -> onAuthenticationSuccess(authentication, webFilterExchange)) @@ -185,10 +176,10 @@ public class SwitchUserWebFilter implements WebFilter { .filter(ServerWebExchangeMatcher.MatchResult::isMatch) .flatMap((matchResult) -> ReactiveSecurityContextHolder.getContext()) .map(SecurityContext::getAuthentication).flatMap((currentAuthentication) -> { - final String username = getUsername(webFilterExchange.getExchange()); + String username = getUsername(webFilterExchange.getExchange()); return attemptSwitchUser(currentAuthentication, username); - }).onErrorResume(AuthenticationException.class, (e) -> onAuthenticationFailure(e, webFilterExchange) - .then(Mono.error(new SwitchUserAuthenticationException(e)))); + }).onErrorResume(AuthenticationException.class, (ex) -> onAuthenticationFailure(ex, webFilterExchange) + .then(Mono.error(new SwitchUserAuthenticationException(ex)))); } /** @@ -220,11 +211,7 @@ public class SwitchUserWebFilter implements WebFilter { @NonNull private Mono attemptSwitchUser(Authentication currentAuthentication, String userName) { Assert.notNull(userName, "The userName can not be null."); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Attempt to switch to user [" + userName + "]"); - } - + this.logger.debug(LogMessage.format("Attempt to switch to user [%s]", userName)); return this.userDetailsService.findByUsername(userName) .switchIfEmpty(Mono.error(this::noTargetAuthenticationException)) .doOnNext(this.userDetailsChecker::check) @@ -233,19 +220,17 @@ public class SwitchUserWebFilter implements WebFilter { @NonNull private Authentication attemptExitUser(Authentication currentAuthentication) { - final Optional sourceAuthentication = extractSourceAuthentication(currentAuthentication); - + Optional sourceAuthentication = extractSourceAuthentication(currentAuthentication); if (!sourceAuthentication.isPresent()) { this.logger.debug("Could not find original user Authentication object!"); throw noOriginalAuthenticationException(); } - return sourceAuthentication.get(); } private Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { - final ServerWebExchange exchange = webFilterExchange.getExchange(); - final SecurityContextImpl securityContext = new SecurityContextImpl(authentication); + ServerWebExchange exchange = webFilterExchange.getExchange(); + SecurityContextImpl securityContext = new SecurityContextImpl(authentication); return this.securityContextRepository.save(exchange, securityContext) .then(this.successHandler.onAuthenticationSuccess(webFilterExchange, authentication)) .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))); @@ -259,21 +244,18 @@ public class SwitchUserWebFilter implements WebFilter { } private Authentication createSwitchUserToken(UserDetails targetUser, Authentication currentAuthentication) { - final Optional sourceAuthentication = extractSourceAuthentication(currentAuthentication); - + Optional sourceAuthentication = extractSourceAuthentication(currentAuthentication); if (sourceAuthentication.isPresent()) { // SEC-1763. Check first if we are already switched. - this.logger.info("Found original switch user granted authority [" + sourceAuthentication.get() + "]"); + this.logger.info( + LogMessage.format("Found original switch user granted authority [%s]", sourceAuthentication.get())); currentAuthentication = sourceAuthentication.get(); } - - final GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(ROLE_PREVIOUS_ADMINISTRATOR, + GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(ROLE_PREVIOUS_ADMINISTRATOR, currentAuthentication); - final Collection targetUserAuthorities = targetUser.getAuthorities(); - - final List extendedTargetUserAuthorities = new ArrayList<>(targetUserAuthorities); + Collection targetUserAuthorities = targetUser.getAuthorities(); + List extendedTargetUserAuthorities = new ArrayList<>(targetUserAuthorities); extendedTargetUserAuthorities.add(switchAuthority); - return new UsernamePasswordAuthenticationToken(targetUser, targetUser.getPassword(), extendedTargetUserAuthorities); } @@ -291,7 +273,7 @@ public class SwitchUserWebFilter implements WebFilter { // iterate over granted authorities and find the 'switch user' authority for (GrantedAuthority authority : currentAuthentication.getAuthorities()) { if (authority instanceof SwitchUserGrantedAuthority) { - final SwitchUserGrantedAuthority switchAuthority = (SwitchUserGrantedAuthority) authority; + SwitchUserGrantedAuthority switchAuthority = (SwitchUserGrantedAuthority) authority; return Optional.of(switchAuthority.getSource()); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java index a27b42d657..5f888db0f9 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java @@ -20,6 +20,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; @@ -70,9 +71,7 @@ public class LogoutWebFilter implements WebFilter { } private Mono logout(WebFilterExchange webFilterExchange, Authentication authentication) { - if (logger.isDebugEnabled()) { - logger.debug("Logging out user '" + authentication + "' and transferring to logout destination"); - } + logger.debug(LogMessage.format("Logging out user '%s' and transferring to logout destination", authentication)); return this.logoutHandler.logout(webFilterExchange, authentication) .then(this.logoutSuccessHandler.onLogoutSuccess(webFilterExchange, authentication)) .subscriberContext(ReactiveSecurityContextHolder.clearContext()); diff --git a/web/src/main/java/org/springframework/security/web/server/authorization/AuthorizationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authorization/AuthorizationWebFilter.java index 7c167f8265..6b8ee79eb5 100644 --- a/web/src/main/java/org/springframework/security/web/server/authorization/AuthorizationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authorization/AuthorizationWebFilter.java @@ -20,6 +20,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.core.context.ReactiveSecurityContextHolder; @@ -48,15 +49,10 @@ public class AuthorizationWebFilter implements WebFilter { return ReactiveSecurityContextHolder.getContext().filter((c) -> c.getAuthentication() != null) .map(SecurityContext::getAuthentication) .as((authentication) -> this.authorizationManager.verify(authentication, exchange)) - .doOnSuccess((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("Authorization successful"); - } - }).doOnError(AccessDeniedException.class, (e) -> { - if (logger.isDebugEnabled()) { - logger.debug("Authorization failed: " + e.getMessage()); - } - }).switchIfEmpty(chain.filter(exchange)); + .doOnSuccess((it) -> logger.debug("Authorization successful")) + .doOnError(AccessDeniedException.class, + (ex) -> logger.debug(LogMessage.format("Authorization failed: %s", ex.getMessage()))) + .switchIfEmpty(chain.filter(exchange)); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManager.java b/web/src/main/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManager.java index a7074b36f8..bfc0ff5a50 100644 --- a/web/src/main/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManager.java +++ b/web/src/main/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManager.java @@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.core.Authentication; @@ -51,11 +52,9 @@ public final class DelegatingReactiveAuthorizationManager implements ReactiveAut public Mono check(Mono authentication, ServerWebExchange exchange) { return Flux.fromIterable(this.mappings).concatMap((mapping) -> mapping.getMatcher().matches(exchange) .filter(MatchResult::isMatch).map(MatchResult::getVariables).flatMap((variables) -> { - if (logger.isDebugEnabled()) { - logger.debug( - "Checking authorization on '" + exchange.getRequest().getPath().pathWithinApplication() - + "' using " + mapping.getEntry()); - } + logger.debug(LogMessage.of(() -> "Checking authorization on '" + + exchange.getRequest().getPath().pathWithinApplication() + "' using " + + mapping.getEntry())); return mapping.getEntry().check(authentication, new AuthorizationContext(exchange, variables)); })).next().defaultIfEmpty(new AuthorizationDecision(false)); } diff --git a/web/src/main/java/org/springframework/security/web/server/authorization/ServerWebExchangeDelegatingServerAccessDeniedHandler.java b/web/src/main/java/org/springframework/security/web/server/authorization/ServerWebExchangeDelegatingServerAccessDeniedHandler.java index 9c7100da90..1fa697839c 100644 --- a/web/src/main/java/org/springframework/security/web/server/authorization/ServerWebExchangeDelegatingServerAccessDeniedHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authorization/ServerWebExchangeDelegatingServerAccessDeniedHandler.java @@ -39,7 +39,7 @@ public class ServerWebExchangeDelegatingServerAccessDeniedHandler implements Ser private final List handlers; - private ServerAccessDeniedHandler defaultHandler = (exchange, e) -> { + private ServerAccessDeniedHandler defaultHandler = (exchange, ex) -> { exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN); return exchange.getResponse().setComplete(); }; diff --git a/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java b/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java index 724e884fff..f3b09650aa 100644 --- a/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java @@ -44,8 +44,8 @@ public class ReactorContextWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - return chain.filter(exchange) - .subscriberContext((c) -> c.hasKey(SecurityContext.class) ? c : withSecurityContext(c, exchange)); + return chain.filter(exchange).subscriberContext( + (context) -> context.hasKey(SecurityContext.class) ? context : withSecurityContext(context, exchange)); } private Context withSecurityContext(Context mainContext, ServerWebExchange exchange) { diff --git a/web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchange.java b/web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchange.java index c4bd23556d..1ad1d1fa8a 100644 --- a/web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchange.java +++ b/web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchange.java @@ -44,7 +44,7 @@ public class SecurityContextServerWebExchange extends ServerWebExchangeDecorator @Override @SuppressWarnings("unchecked") public Mono getPrincipal() { - return this.context.map((c) -> (T) c.getAuthentication()); + return this.context.map((context) -> (T) context.getAuthentication()); } } diff --git a/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java index 3058a82e1e..3bd3ed8b8a 100644 --- a/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java @@ -20,9 +20,11 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.context.SecurityContext; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; /** * Stores the {@link SecurityContext} in the @@ -59,31 +61,22 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity return exchange.getSession().doOnNext((session) -> { if (context == null) { session.getAttributes().remove(this.springSecurityContextAttrName); - if (logger.isDebugEnabled()) { - logger.debug("Removed SecurityContext stored in WebSession: '" + session + "'"); - } + logger.debug(LogMessage.format("Removed SecurityContext stored in WebSession: '%s'", session)); } else { session.getAttributes().put(this.springSecurityContextAttrName, context); - if (logger.isDebugEnabled()) { - logger.debug("Saved SecurityContext '" + context + "' in WebSession: '" + session + "'"); - } + logger.debug(LogMessage.format("Saved SecurityContext '%s' in WebSession: '%s'", context, session)); } - }).flatMap((session) -> session.changeSessionId()); + }).flatMap(WebSession::changeSessionId); } @Override public Mono load(ServerWebExchange exchange) { return exchange.getSession().flatMap((session) -> { SecurityContext context = (SecurityContext) session.getAttribute(this.springSecurityContextAttrName); - if (logger.isDebugEnabled()) { - if (context == null) { - logger.debug("No SecurityContext found in WebSession: '" + session + "'"); - } - else { - logger.debug("Found SecurityContext '" + context + "' in WebSession: '" + session + "'"); - } - } + logger.debug((context != null) + ? LogMessage.format("Found SecurityContext '%s' in WebSession: '%s'", context, session) + : LogMessage.format("No SecurityContext found in WebSession: '%s'", session)); return Mono.justOrEmpty(context); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CookieServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/CookieServerCsrfTokenRepository.java index 18fa4d6eeb..8b2c952911 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CookieServerCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CookieServerCsrfTokenRepository.java @@ -77,10 +77,8 @@ public final class CookieServerCsrfTokenRepository implements ServerCsrfTokenRep int maxAge = !tokenValue.isEmpty() ? -1 : 0; String path = (this.cookiePath != null) ? this.cookiePath : getRequestContext(exchange.getRequest()); boolean secure = exchange.getRequest().getSslInfo() != null; - ResponseCookie cookie = ResponseCookie.from(this.cookieName, tokenValue).domain(this.cookieDomain) .httpOnly(this.cookieHttpOnly).maxAge(maxAge).path(path).secure(secure).build(); - exchange.getResponse().addCookie(cookie); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 87655dee60..46ffb2cafb 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -31,6 +31,7 @@ import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; @@ -115,12 +116,11 @@ public class CsrfWebFilter implements WebFilter { if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { return chain.filter(exchange).then(Mono.empty()); } - - return this.requireCsrfProtectionMatcher.matches(exchange).filter((matchResult) -> matchResult.isMatch()) + return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch) .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) .flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain)) .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) - .onErrorResume(CsrfException.class, (e) -> this.accessDeniedHandler.handle(exchange, e)); + .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex)); } public static void skipExchange(ServerWebExchange exchange) { @@ -181,7 +181,7 @@ public class CsrfWebFilter implements WebFilter { @Override public Mono matches(ServerWebExchange exchange) { return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod())) - .filter((m) -> ALLOWED_METHODS.contains(m)).flatMap((m) -> MatchResult.notMatch()) + .filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch()) .switchIfEmpty(MatchResult.match()); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java index 42fc487fe0..eb49369e6f 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java @@ -65,23 +65,21 @@ public final class DefaultCsrfToken implements CsrfToken { } @Override - public boolean equals(Object o) { - if (this == o) { + public boolean equals(Object obj) { + if (this == obj) { return true; } - if (o == null || !(o instanceof CsrfToken)) { + if (obj == null || !(obj instanceof CsrfToken)) { return false; } - - CsrfToken that = (CsrfToken) o; - - if (!getToken().equals(that.getToken())) { + CsrfToken other = (CsrfToken) obj; + if (!getToken().equals(other.getToken())) { return false; } - if (!getParameterName().equals(that.getParameterName())) { + if (!getParameterName().equals(other.getParameterName())) { return false; } - return getHeaderName().equals(that.getHeaderName()); + return getHeaderName().equals(other.getHeaderName()); } @Override diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java index 57002cf9db..9a4714f58a 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java @@ -73,8 +73,8 @@ public class WebSessionServerCsrfTokenRepository implements ServerCsrfTokenRepos @Override public Mono loadToken(ServerWebExchange exchange) { - return exchange.getSession().filter((s) -> s.getAttributes().containsKey(this.sessionAttributeName)) - .map((s) -> s.getAttribute(this.sessionAttributeName)); + return exchange.getSession().filter((session) -> session.getAttributes().containsKey(this.sessionAttributeName)) + .map((session) -> session.getAttribute(this.sessionAttributeName)); } /** diff --git a/web/src/main/java/org/springframework/security/web/server/header/ClearSiteDataServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/ClearSiteDataServerHttpHeadersWriter.java index 31b94b4692..1f44a9f5bb 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/ClearSiteDataServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/ClearSiteDataServerHttpHeadersWriter.java @@ -58,9 +58,7 @@ public final class ClearSiteDataServerHttpHeadersWriter implements ServerHttpHea if (isSecure(exchange)) { return this.headerWriterDelegate.writeHttpHeaders(exchange); } - else { - return Mono.empty(); - } + return Mono.empty(); } /** @@ -72,7 +70,15 @@ public final class ClearSiteDataServerHttpHeadersWriter implements ServerHttpHea */ public enum Directive { - CACHE("cache"), COOKIES("cookies"), STORAGE("storage"), EXECUTION_CONTEXTS("executionContexts"), ALL("*"); + CACHE("cache"), + + COOKIES("cookies"), + + STORAGE("storage"), + + EXECUTION_CONTEXTS("executionContexts"), + + ALL("*"); private final String headerValue; diff --git a/web/src/main/java/org/springframework/security/web/server/header/ContentSecurityPolicyServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/ContentSecurityPolicyServerHttpHeadersWriter.java index b053b418af..f72bc5e14d 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/ContentSecurityPolicyServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/ContentSecurityPolicyServerHttpHeadersWriter.java @@ -18,6 +18,7 @@ package org.springframework.security.web.server.header; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -67,16 +68,12 @@ public final class ContentSecurityPolicyServerHttpHeadersWriter implements Serve } private ServerHttpHeadersWriter createDelegate() { - if (this.policyDirectives != null) { - // @formatter:off - return StaticServerHttpHeadersWriter.builder() - .header(resolveHeader(this.reportOnly), this.policyDirectives) - .build(); - // @formatter:on - } - else { + if (this.policyDirectives == null) { return null; } + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(resolveHeader(this.reportOnly), this.policyDirectives); + return builder.build(); } private static String resolveHeader(boolean reportOnly) { diff --git a/web/src/main/java/org/springframework/security/web/server/header/FeaturePolicyServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/FeaturePolicyServerHttpHeadersWriter.java index 6cfe98e55e..ba79c9e467 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/FeaturePolicyServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/FeaturePolicyServerHttpHeadersWriter.java @@ -18,6 +18,7 @@ package org.springframework.security.web.server.header; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -49,11 +50,9 @@ public final class FeaturePolicyServerHttpHeadersWriter implements ServerHttpHea } private static ServerHttpHeadersWriter createDelegate(String policyDirectives) { - // @formatter:off - return StaticServerHttpHeadersWriter.builder() - .header(FEATURE_POLICY, policyDirectives) - .build(); - // @formatter:on + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(FEATURE_POLICY, policyDirectives); + return builder.build(); } } diff --git a/web/src/main/java/org/springframework/security/web/server/header/ReferrerPolicyServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/ReferrerPolicyServerHttpHeadersWriter.java index 64457b09e1..ddc6ff59ef 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/ReferrerPolicyServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/ReferrerPolicyServerHttpHeadersWriter.java @@ -22,6 +22,7 @@ import java.util.Map; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -57,25 +58,28 @@ public final class ReferrerPolicyServerHttpHeadersWriter implements ServerHttpHe } private static ServerHttpHeadersWriter createDelegate(ReferrerPolicy policy) { - // @formatter:off - return StaticServerHttpHeadersWriter.builder() - .header(REFERRER_POLICY, policy.getPolicy()) - .build(); - // @formatter:on + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(REFERRER_POLICY, policy.getPolicy()); + return builder.build(); } public enum ReferrerPolicy { - // @formatter:off NO_REFERRER("no-referrer"), + NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"), + SAME_ORIGIN("same-origin"), + ORIGIN("origin"), + STRICT_ORIGIN("strict-origin"), + ORIGIN_WHEN_CROSS_ORIGIN("origin-when-cross-origin"), + STRICT_ORIGIN_WHEN_CROSS_ORIGIN("strict-origin-when-cross-origin"), + UNSAFE_URL("unsafe-url"); - // @formatter:on private static final Map REFERRER_POLICIES; @@ -87,7 +91,7 @@ public final class ReferrerPolicyServerHttpHeadersWriter implements ServerHttpHe REFERRER_POLICIES = Collections.unmodifiableMap(referrerPolicies); } - private String policy; + private final String policy; ReferrerPolicy(String policy) { this.policy = policy; diff --git a/web/src/main/java/org/springframework/security/web/server/header/StrictTransportSecurityServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/StrictTransportSecurityServerHttpHeadersWriter.java index f4fcc58fa8..f9452233e7 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/StrictTransportSecurityServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/StrictTransportSecurityServerHttpHeadersWriter.java @@ -20,6 +20,7 @@ import java.time.Duration; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; import org.springframework.web.server.ServerWebExchange; /** @@ -40,9 +41,6 @@ public final class StrictTransportSecurityServerHttpHeadersWriter implements Ser private ServerHttpHeadersWriter delegate; - /** - * - */ public StrictTransportSecurityServerHttpHeadersWriter() { setIncludeSubDomains(true); setMaxAge(Duration.ofDays(365L)); @@ -92,8 +90,9 @@ public final class StrictTransportSecurityServerHttpHeadersWriter implements Ser } private void updateDelegate() { - this.delegate = StaticServerHttpHeadersWriter.builder() - .header(STRICT_TRANSPORT_SECURITY, this.maxAge + this.subdomain + this.preload).build(); + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(STRICT_TRANSPORT_SECURITY, this.maxAge + this.subdomain + this.preload); + this.delegate = builder.build(); } private boolean isSecure(ServerWebExchange exchange) { diff --git a/web/src/main/java/org/springframework/security/web/server/header/XFrameOptionsServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/XFrameOptionsServerHttpHeadersWriter.java index 3e5beb7046..d6712fb287 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/XFrameOptionsServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/XFrameOptionsServerHttpHeadersWriter.java @@ -18,6 +18,7 @@ package org.springframework.security.web.server.header; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; import org.springframework.web.server.ServerWebExchange; /** @@ -67,6 +68,7 @@ public class XFrameOptionsServerHttpHeadersWriter implements ServerHttpHeadersWr * content in any frame. */ DENY, + /** * A browser receiving content with this header field MUST NOT display this * content in any frame from a page of different origin than the content itself. @@ -79,9 +81,9 @@ public class XFrameOptionsServerHttpHeadersWriter implements ServerHttpHeadersWr } private static ServerHttpHeadersWriter createDelegate(Mode mode) { - // @formatter:off - return StaticServerHttpHeadersWriter.builder().header(X_FRAME_OPTIONS, mode.name()).build(); - // @formatter:on + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(X_FRAME_OPTIONS, mode.name()); + return builder.build(); } diff --git a/web/src/main/java/org/springframework/security/web/server/header/XXssProtectionServerHttpHeadersWriter.java b/web/src/main/java/org/springframework/security/web/server/header/XXssProtectionServerHttpHeadersWriter.java index 2112d4afbb..2437ed3042 100644 --- a/web/src/main/java/org/springframework/security/web/server/header/XXssProtectionServerHttpHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/server/header/XXssProtectionServerHttpHeadersWriter.java @@ -18,6 +18,8 @@ package org.springframework.security.web.server.header; import reactor.core.publisher.Mono; +import org.springframework.security.web.server.header.StaticServerHttpHeadersWriter.Builder; +import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; /** @@ -86,16 +88,15 @@ public class XXssProtectionServerHttpHeadersWriter implements ServerHttpHeadersW * @param block the new value */ public void setBlock(boolean block) { - if (!this.enabled && block) { - throw new IllegalArgumentException("Cannot set block to true with enabled false"); - } + Assert.isTrue(this.enabled || !block, "Cannot set block to true with enabled false"); this.block = block; updateDelegate(); } private void updateDelegate() { - - this.delegate = StaticServerHttpHeadersWriter.builder().header(X_XSS_PROTECTION, createHeaderValue()).build(); + Builder builder = StaticServerHttpHeadersWriter.builder(); + builder.header(X_XSS_PROTECTION, createHeaderValue()); + this.delegate = builder.build(); } private String createHeaderValue() { diff --git a/web/src/main/java/org/springframework/security/web/server/jackson2/WebServerJackson2Module.java b/web/src/main/java/org/springframework/security/web/server/jackson2/WebServerJackson2Module.java index ce75493721..ceea54bdbc 100644 --- a/web/src/main/java/org/springframework/security/web/server/jackson2/WebServerJackson2Module.java +++ b/web/src/main/java/org/springframework/security/web/server/jackson2/WebServerJackson2Module.java @@ -40,8 +40,12 @@ import org.springframework.security.web.server.csrf.DefaultCsrfToken; */ public class WebServerJackson2Module extends SimpleModule { + private static final String NAME = WebServerJackson2Module.class.getName(); + + private static final Version VERSION = new Version(1, 0, 0, null, null, null); + public WebServerJackson2Module() { - super(WebServerJackson2Module.class.getName(), new Version(1, 0, 0, null, null, null)); + super(NAME, VERSION); } @Override diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/CookieServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/CookieServerRequestCache.java index ca8bd11c79..6d92e9e8d0 100644 --- a/web/src/main/java/org/springframework/security/web/server/savedrequest/CookieServerRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/CookieServerRequestCache.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpCookie; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -75,9 +76,7 @@ public class CookieServerRequestCache implements ServerRequestCache { .map(ServerHttpResponse::getCookies).doOnNext((cookies) -> { ResponseCookie redirectUriCookie = createRedirectUriCookie(exchange.getRequest()); cookies.add(REDIRECT_URI_COOKIE_NAME, redirectUriCookie); - if (logger.isDebugEnabled()) { - logger.debug("Request added to Cookie: " + redirectUriCookie); - } + logger.debug(LogMessage.format("Request added to Cookie: %s", redirectUriCookie)); }).then(); } @@ -86,7 +85,7 @@ public class CookieServerRequestCache implements ServerRequestCache { MultiValueMap cookieMap = exchange.getRequest().getCookies(); return Mono.justOrEmpty(cookieMap.getFirst(REDIRECT_URI_COOKIE_NAME)).map(HttpCookie::getValue) .map(CookieServerRequestCache::decodeCookie) - .onErrorResume(IllegalArgumentException.class, (e) -> Mono.empty()).map(URI::create); + .onErrorResume(IllegalArgumentException.class, (ex) -> Mono.empty()).map(URI::create); } @Override @@ -100,7 +99,6 @@ public class CookieServerRequestCache implements ServerRequestCache { String path = request.getPath().pathWithinApplication().value(); String query = request.getURI().getRawQuery(); String redirectUri = path + ((query != null) ? "?" + query : ""); - return createResponseCookie(request, encodeCookie(redirectUri), COOKIE_MAX_AGE); } diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java index 37822da309..df6374cf5b 100644 --- a/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -72,9 +73,7 @@ public class WebSessionServerRequestCache implements ServerRequestCache { .flatMap((m) -> exchange.getSession()).map(WebSession::getAttributes).doOnNext((attrs) -> { String requestPath = pathInApplication(exchange.getRequest()); attrs.put(this.sessionAttrName, requestPath); - if (logger.isDebugEnabled()) { - logger.debug("Request added to WebSession: '" + requestPath + "'"); - } + logger.debug(LogMessage.format("Request added to WebSession: '%s'", requestPath)); }).then(); } @@ -91,9 +90,7 @@ public class WebSessionServerRequestCache implements ServerRequestCache { String requestPath = pathInApplication(exchange.getRequest()); boolean removed = attributes.remove(this.sessionAttrName, requestPath); if (removed) { - if (logger.isDebugEnabled()) { - logger.debug("Request removed from WebSession: '" + requestPath + "'"); - } + logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath)); } return removed; }).map((attributes) -> exchange.getRequest()); diff --git a/web/src/main/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilter.java b/web/src/main/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilter.java index da69b1b028..f1b7eadedc 100644 --- a/web/src/main/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/transport/HttpsRedirectWebFilter.java @@ -51,9 +51,6 @@ public final class HttpsRedirectWebFilter implements WebFilter { private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy(); - /** - * {@inheritDoc} - */ @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return Mono.just(exchange).filter(this::isInsecure).flatMap(this.requiresHttpsRedirectMatcher::matches) @@ -80,7 +77,6 @@ public final class HttpsRedirectWebFilter implements WebFilter { * @param requiresHttpsRedirectMatcher the {@link ServerWebExchangeMatcher} to use */ public void setRequiresHttpsRedirectMatcher(ServerWebExchangeMatcher requiresHttpsRedirectMatcher) { - Assert.notNull(requiresHttpsRedirectMatcher, "requiresHttpsRedirectMatcher cannot be null"); this.requiresHttpsRedirectMatcher = requiresHttpsRedirectMatcher; } @@ -91,17 +87,12 @@ public final class HttpsRedirectWebFilter implements WebFilter { private URI createRedirectUri(ServerWebExchange exchange) { int port = exchange.getRequest().getURI().getPort(); - UriComponentsBuilder builder = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()); - if (port > 0) { Integer httpsPort = this.portMapper.lookupHttpsPort(port); - if (httpsPort == null) { - throw new IllegalStateException("HTTP Port '" + port + "' does not have a corresponding HTTPS Port"); - } + Assert.state(httpsPort != null, () -> "HTTP Port '" + port + "' does not have a corresponding HTTPS Port"); builder.port(httpsPort); } - return builder.scheme("https").build().toUri(); } diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java index 7f56d33c98..4248433f75 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java @@ -75,7 +75,6 @@ public class LoginPageGeneratingWebFilter implements WebFilter { } private Mono createBuffer(ServerWebExchange exchange) { - Mono token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty()); return token.map(LoginPageGeneratingWebFilter::csrfToken).defaultIfEmpty("").map((csrfTokenHtmlInput) -> { byte[] bytes = createPage(exchange, csrfTokenHtmlInput); @@ -87,18 +86,29 @@ public class LoginPageGeneratingWebFilter implements WebFilter { private byte[] createPage(ServerWebExchange exchange, String csrfTokenHtmlInput) { MultiValueMap queryParams = exchange.getRequest().getQueryParams(); String contextPath = exchange.getRequest().getPath().contextPath().value(); - String page = "\n" + "\n" + " \n" + " \n" - + " \n" - + " \n" + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" + " \n" + "
\n" - + formLogin(queryParams, contextPath, csrfTokenHtmlInput) - + oauth2LoginLinks(queryParams, contextPath, this.oauth2AuthenticationUrlToClientName) + "
\n" - + " \n" + ""; - - return page.getBytes(Charset.defaultCharset()); + StringBuilder page = new StringBuilder(); + page.append("\n"); + page.append("\n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" Please sign in\n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append("
\n"); + page.append(formLogin(queryParams, contextPath, csrfTokenHtmlInput)); + page.append(oauth2LoginLinks(queryParams, contextPath, this.oauth2AuthenticationUrlToClientName)); + page.append("
\n"); + page.append(" \n"); + page.append(""); + return page.toString().getBytes(Charset.defaultCharset()); } private String formLogin(MultiValueMap queryParams, String contextPath, String csrfTokenHtmlInput) { @@ -107,17 +117,24 @@ public class LoginPageGeneratingWebFilter implements WebFilter { } boolean isError = queryParams.containsKey("error"); boolean isLogoutSuccess = queryParams.containsKey("logout"); - return "
\n" - + " \n" + createError(isError) - + createLogoutSuccess(isLogoutSuccess) + "

\n" - + " \n" - + " \n" - + "

\n" + "

\n" - + " \n" - + " \n" - + "

\n" + csrfTokenHtmlInput - + " \n" - + "
\n"; + StringBuilder page = new StringBuilder(); + page.append("
\n"); + page.append(" \n"); + page.append(createError(isError)); + page.append(createLogoutSuccess(isLogoutSuccess)); + page.append("

\n"); + page.append(" \n"); + page.append(" \n"); + page.append("

\n" + "

\n"); + page.append(" \n"); + page.append(" \n"); + page.append("

\n"); + page.append(csrfTokenHtmlInput); + page.append(" \n"); + page.append("
\n"); + return page.toString(); } private static String oauth2LoginLinks(MultiValueMap queryParams, String contextPath, diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java index b32c5af1b2..983ddb015b 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java @@ -54,7 +54,6 @@ public class LogoutPageGeneratingWebFilter implements WebFilter { result.setStatusCode(HttpStatus.OK); result.getHeaders().setContentType(MediaType.TEXT_HTML); return result.writeWith(createBuffer(exchange)); - // .doOnError( (error) -> DataBufferUtils.release(buffer)); } private Mono createBuffer(ServerWebExchange exchange) { @@ -67,20 +66,31 @@ public class LogoutPageGeneratingWebFilter implements WebFilter { } private static byte[] createPage(String csrfTokenHtmlInput) { - String page = "\n" + "\n" + " \n" + " \n" - + " \n" - + " \n" + " \n" - + " Confirm Log Out?\n" - + " \n" - + " \n" - + " \n" + " \n" + "
\n" - + "
\n" - + " \n" - + csrfTokenHtmlInput - + " \n" - + "
\n" + "
\n" + " \n" + ""; - - return page.getBytes(Charset.defaultCharset()); + StringBuilder page = new StringBuilder(); + page.append("\n"); + page.append("\n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" Confirm Log Out?\n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append(" \n"); + page.append("
\n"); + page.append("
\n"); + page.append(" \n"); + page.append(csrfTokenHtmlInput); + page.append(" \n"); + page.append("
\n"); + page.append("
\n"); + page.append(" \n"); + page.append(""); + return page.toString().getBytes(Charset.defaultCharset()); } private static String csrfToken(CsrfToken token) { diff --git a/web/src/main/java/org/springframework/security/web/server/util/matcher/AndServerWebExchangeMatcher.java b/web/src/main/java/org/springframework/security/web/server/util/matcher/AndServerWebExchangeMatcher.java index d5f7ae92ec..c898b7639c 100644 --- a/web/src/main/java/org/springframework/security/web/server/util/matcher/AndServerWebExchangeMatcher.java +++ b/web/src/main/java/org/springframework/security/web/server/util/matcher/AndServerWebExchangeMatcher.java @@ -26,6 +26,7 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -56,18 +57,13 @@ public class AndServerWebExchangeMatcher implements ServerWebExchangeMatcher { public Mono matches(ServerWebExchange exchange) { return Mono.defer(() -> { Map variables = new HashMap<>(); - return Flux.fromIterable(this.matchers).doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("Trying to match using " + it); - } - }).flatMap((matcher) -> matcher.matches(exchange)) + return Flux.fromIterable(this.matchers) + .doOnNext((matcher) -> logger.debug(LogMessage.format("Trying to match using %s", matcher))) + .flatMap((matcher) -> matcher.matches(exchange)) .doOnNext((matchResult) -> variables.putAll(matchResult.getVariables())).all(MatchResult::isMatch) .flatMap((allMatch) -> allMatch ? MatchResult.match(variables) : MatchResult.notMatch()) - .doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug(it.isMatch() ? "All requestMatchers returned true" : "Did not match"); - } - }); + .doOnNext((matchResult) -> logger + .debug(matchResult.isMatch() ? "All requestMatchers returned true" : "Did not match")); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/util/matcher/MediaTypeServerWebExchangeMatcher.java b/web/src/main/java/org/springframework/security/web/server/util/matcher/MediaTypeServerWebExchangeMatcher.java index d455e37977..3ddc6b09df 100644 --- a/web/src/main/java/org/springframework/security/web/server/util/matcher/MediaTypeServerWebExchangeMatcher.java +++ b/web/src/main/java/org/springframework/security/web/server/util/matcher/MediaTypeServerWebExchangeMatcher.java @@ -26,6 +26,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.util.Assert; @@ -80,13 +81,9 @@ public class MediaTypeServerWebExchangeMatcher implements ServerWebExchangeMatch this.logger.debug("Failed to parse MediaTypes, returning false", ex); return MatchResult.notMatch(); } - if (this.logger.isDebugEnabled()) { - this.logger.debug("httpRequestMediaTypes=" + httpRequestMediaTypes); - } + this.logger.debug(LogMessage.format("httpRequestMediaTypes=%s", httpRequestMediaTypes)); for (MediaType httpRequestMediaType : httpRequestMediaTypes) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Processing " + httpRequestMediaType); - } + this.logger.debug(LogMessage.format("Processing %s", httpRequestMediaType)); if (shouldIgnore(httpRequestMediaType)) { this.logger.debug("Ignoring"); continue; @@ -98,10 +95,8 @@ public class MediaTypeServerWebExchangeMatcher implements ServerWebExchangeMatch } for (MediaType matchingMediaType : this.matchingMediaTypes) { boolean isCompatibleWith = matchingMediaType.isCompatibleWith(httpRequestMediaType); - if (this.logger.isDebugEnabled()) { - this.logger.debug(matchingMediaType + " .isCompatibleWith " + httpRequestMediaType + " = " - + isCompatibleWith); - } + this.logger.debug(LogMessage.format("%s .isCompatibleWith %s = %s", matchingMediaType, + httpRequestMediaType, isCompatibleWith)); if (isCompatibleWith) { return MatchResult.match(); } diff --git a/web/src/main/java/org/springframework/security/web/server/util/matcher/NegatedServerWebExchangeMatcher.java b/web/src/main/java/org/springframework/security/web/server/util/matcher/NegatedServerWebExchangeMatcher.java index e0b3b17993..ce414b4087 100644 --- a/web/src/main/java/org/springframework/security/web/server/util/matcher/NegatedServerWebExchangeMatcher.java +++ b/web/src/main/java/org/springframework/security/web/server/util/matcher/NegatedServerWebExchangeMatcher.java @@ -20,6 +20,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -44,12 +45,12 @@ public class NegatedServerWebExchangeMatcher implements ServerWebExchangeMatcher @Override public Mono matches(ServerWebExchange exchange) { - return this.matcher.matches(exchange).flatMap((m) -> m.isMatch() ? MatchResult.notMatch() : MatchResult.match()) - .doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("matches = " + it.isMatch()); - } - }); + return this.matcher.matches(exchange).flatMap(this::negate) + .doOnNext((matchResult) -> logger.debug(LogMessage.format("matches = %s", matchResult.isMatch()))); + } + + private Mono negate(MatchResult matchResult) { + return matchResult.isMatch() ? MatchResult.notMatch() : MatchResult.match(); } @Override diff --git a/web/src/main/java/org/springframework/security/web/server/util/matcher/OrServerWebExchangeMatcher.java b/web/src/main/java/org/springframework/security/web/server/util/matcher/OrServerWebExchangeMatcher.java index 5ceb1b1788..74eca6a7d4 100644 --- a/web/src/main/java/org/springframework/security/web/server/util/matcher/OrServerWebExchangeMatcher.java +++ b/web/src/main/java/org/springframework/security/web/server/util/matcher/OrServerWebExchangeMatcher.java @@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -52,21 +53,16 @@ public class OrServerWebExchangeMatcher implements ServerWebExchangeMatcher { @Override public Mono matches(ServerWebExchange exchange) { - return Flux.fromIterable(this.matchers).doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug("Trying to match using " + it); - } - }).flatMap((m) -> m.matches(exchange)).filter(MatchResult::isMatch).next().switchIfEmpty(MatchResult.notMatch()) - .doOnNext((it) -> { - if (logger.isDebugEnabled()) { - logger.debug(it.isMatch() ? "matched" : "No matches found"); - } - }); + return Flux.fromIterable(this.matchers) + .doOnNext((matcher) -> logger.debug(LogMessage.format("Trying to match using %s", matcher))) + .flatMap((matcher) -> matcher.matches(exchange)).filter(MatchResult::isMatch).next() + .switchIfEmpty(MatchResult.notMatch()) + .doOnNext((matchResult) -> logger.debug(matchResult.isMatch() ? "matched" : "No matches found")); } @Override public String toString() { - return "OrServerWebExchangeMatcher{" + "matchers=" + this.matchers + '}'; + return "OrServerWebExchangeMatcher{matchers=" + this.matchers + '}'; } } diff --git a/web/src/main/java/org/springframework/security/web/server/util/matcher/ServerWebExchangeMatchers.java b/web/src/main/java/org/springframework/security/web/server/util/matcher/ServerWebExchangeMatchers.java index 9fe5da54a6..5b7bec52c7 100644 --- a/web/src/main/java/org/springframework/security/web/server/util/matcher/ServerWebExchangeMatchers.java +++ b/web/src/main/java/org/springframework/security/web/server/util/matcher/ServerWebExchangeMatchers.java @@ -32,6 +32,9 @@ import org.springframework.web.server.ServerWebExchange; */ public abstract class ServerWebExchangeMatchers { + private ServerWebExchangeMatchers() { + } + /** * Creates a matcher that matches on the specific method and any of the provided * patterns. @@ -75,14 +78,13 @@ public abstract class ServerWebExchangeMatchers { // which otherwise can cause problems with adding multiple entries to an ordered // LinkedHashMap return new ServerWebExchangeMatcher() { + @Override public Mono matches(ServerWebExchange exchange) { return ServerWebExchangeMatcher.MatchResult.match(); } + }; } - private ServerWebExchangeMatchers() { - } - } diff --git a/web/src/main/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessor.java b/web/src/main/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessor.java index d904f4d1fe..1ea18dbe78 100644 --- a/web/src/main/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessor.java +++ b/web/src/main/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessor.java @@ -65,7 +65,6 @@ public final class CsrfRequestDataValueProcessor implements RequestDataValueProc request.removeAttribute(this.DISABLE_CSRF_TOKEN_ATTR); return Collections.emptyMap(); } - CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); if (token == null) { return Collections.emptyMap(); diff --git a/web/src/main/java/org/springframework/security/web/servlet/util/matcher/MvcRequestMatcher.java b/web/src/main/java/org/springframework/security/web/servlet/util/matcher/MvcRequestMatcher.java index 8ca395e047..415e6ca6c1 100644 --- a/web/src/main/java/org/springframework/security/web/servlet/util/matcher/MvcRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/servlet/util/matcher/MvcRequestMatcher.java @@ -127,17 +127,13 @@ public class MvcRequestMatcher implements RequestMatcher, RequestVariablesExtrac public String toString() { StringBuilder sb = new StringBuilder(); sb.append("Mvc [pattern='").append(this.pattern).append("'"); - if (this.servletPath != null) { sb.append(", servletPath='").append(this.servletPath).append("'"); } - if (this.method != null) { sb.append(", ").append(this.method); } - sb.append("]"); - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java b/web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java index 22ca477e2d..3a2dd9d36d 100644 --- a/web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java +++ b/web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java @@ -17,7 +17,6 @@ package org.springframework.security.web.servletapi; import java.io.IOException; -import java.security.Principal; import java.util.List; import javax.servlet.AsyncContext; @@ -225,17 +224,21 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory { super.login(username, password); return; } - Authentication authentication; - try { - authentication = authManager.authenticate(new UsernamePasswordAuthenticationToken(username, password)); - } - catch (AuthenticationException loginFailed) { - SecurityContextHolder.clearContext(); - throw new ServletException(loginFailed.getMessage(), loginFailed); - } + Authentication authentication = getAuthentication(authManager, username, password); SecurityContextHolder.getContext().setAuthentication(authentication); } + private Authentication getAuthentication(AuthenticationManager authManager, String username, String password) + throws ServletException { + try { + return authManager.authenticate(new UsernamePasswordAuthenticationToken(username, password)); + } + catch (AuthenticationException ex) { + SecurityContextHolder.clearContext(); + throw new ServletException(ex.getMessage(), ex); + } + } + @Override public void logout() throws ServletException { List handlers = HttpServlet3RequestFactory.this.logoutHandlers; @@ -252,8 +255,7 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory { } private boolean isAuthenticated() { - Principal userPrincipal = getUserPrincipal(); - return userPrincipal != null; + return getUserPrincipal() != null; } } diff --git a/web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java b/web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java index 42932190ff..e9d434e4e0 100644 --- a/web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java +++ b/web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java @@ -88,12 +88,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest */ private Authentication getAuthentication() { Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - - if (!this.trustResolver.isAnonymous(auth)) { - return auth; - } - - return null; + return (!this.trustResolver.isAnonymous(auth)) ? auth : null; } /** @@ -105,15 +100,12 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest @Override public String getRemoteUser() { Authentication auth = getAuthentication(); - if ((auth == null) || (auth.getPrincipal() == null)) { return null; } - if (auth.getPrincipal() instanceof UserDetails) { return ((UserDetails) auth.getPrincipal()).getUsername(); } - return auth.getPrincipal().toString(); } @@ -125,37 +117,29 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest @Override public Principal getUserPrincipal() { Authentication auth = getAuthentication(); - if ((auth == null) || (auth.getPrincipal() == null)) { return null; } - return auth; } private boolean isGranted(String role) { Authentication auth = getAuthentication(); - if (this.rolePrefix != null && role != null && !role.startsWith(this.rolePrefix)) { role = this.rolePrefix + role; } - if ((auth == null) || (auth.getPrincipal() == null)) { return false; } - Collection authorities = auth.getAuthorities(); - if (authorities == null) { return false; } - for (GrantedAuthority grantedAuthority : authorities) { if (role.equals(grantedAuthority.getAuthority())) { return true; } } - return false; } diff --git a/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java b/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java index 8bdc2af005..57d5928ec0 100644 --- a/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java +++ b/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java @@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; @@ -101,7 +102,6 @@ public class ConcurrentSessionFilter extends GenericFilterBean { HttpServletRequest request = event.getRequest(); HttpServletResponse response = event.getResponse(); SessionInformation info = event.getSessionInformation(); - this.redirectStrategy.sendRedirect(request, response, determineExpiredUrl(request, info)); }; } @@ -120,35 +120,30 @@ public class ConcurrentSessionFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { HttpSession session = request.getSession(false); - if (session != null) { SessionInformation info = this.sessionRegistry.getSessionInformation(session.getId()); - if (info != null) { if (info.isExpired()) { // Expired - abort processing - if (this.logger.isDebugEnabled()) { - this.logger.debug("Requested session ID " + request.getRequestedSessionId() + " has expired."); - } + this.logger.debug(LogMessage + .of(() -> "Requested session ID " + request.getRequestedSessionId() + " has expired.")); doLogout(request, response); - this.sessionInformationExpiredStrategy .onExpiredSessionDetected(new SessionInformationExpiredEvent(info, request, response)); return; } - else { - // Non-expired - update last request date/time - this.sessionRegistry.refreshLastRequest(info.getSessionId()); - } + // Non-expired - update last request date/time + this.sessionRegistry.refreshLastRequest(info.getSessionId()); } } - chain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/session/HttpSessionDestroyedEvent.java b/web/src/main/java/org/springframework/security/web/session/HttpSessionDestroyedEvent.java index 7cce0db2b4..ec06003d64 100644 --- a/web/src/main/java/org/springframework/security/web/session/HttpSessionDestroyedEvent.java +++ b/web/src/main/java/org/springframework/security/web/session/HttpSessionDestroyedEvent.java @@ -47,11 +47,8 @@ public class HttpSessionDestroyedEvent extends SessionDestroyedEvent { @Override public List getSecurityContexts() { HttpSession session = getSession(); - Enumeration attributes = session.getAttributeNames(); - ArrayList contexts = new ArrayList<>(); - while (attributes.hasMoreElements()) { String attributeName = attributes.nextElement(); Object attributeValue = session.getAttribute(attributeName); @@ -59,7 +56,6 @@ public class HttpSessionDestroyedEvent extends SessionDestroyedEvent { contexts.add((SecurityContext) attributeValue); } } - return contexts; } diff --git a/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java b/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java index 6dc5162626..146922129a 100644 --- a/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java +++ b/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java @@ -17,6 +17,7 @@ package org.springframework.security.web.session; import javax.servlet.ServletContext; +import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSessionEvent; import javax.servlet.http.HttpSessionIdListener; import javax.servlet.http.HttpSessionListener; @@ -25,6 +26,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationEvent; +import org.springframework.core.log.LogMessage; import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; /** @@ -59,14 +62,7 @@ public class HttpSessionEventPublisher implements HttpSessionListener, HttpSessi */ @Override public void sessionCreated(HttpSessionEvent event) { - HttpSessionCreatedEvent e = new HttpSessionCreatedEvent(event.getSession()); - Log log = LogFactory.getLog(LOGGER_NAME); - - if (log.isDebugEnabled()) { - log.debug("Publishing event: " + e); - } - - getContext(event.getSession().getServletContext()).publishEvent(e); + extracted(event.getSession(), new HttpSessionCreatedEvent(event.getSession())); } /** @@ -76,26 +72,18 @@ public class HttpSessionEventPublisher implements HttpSessionListener, HttpSessi */ @Override public void sessionDestroyed(HttpSessionEvent event) { - HttpSessionDestroyedEvent e = new HttpSessionDestroyedEvent(event.getSession()); - Log log = LogFactory.getLog(LOGGER_NAME); - - if (log.isDebugEnabled()) { - log.debug("Publishing event: " + e); - } - - getContext(event.getSession().getServletContext()).publishEvent(e); + extracted(event.getSession(), new HttpSessionDestroyedEvent(event.getSession())); } @Override public void sessionIdChanged(HttpSessionEvent event, String oldSessionId) { - HttpSessionIdChangedEvent e = new HttpSessionIdChangedEvent(event.getSession(), oldSessionId); + extracted(event.getSession(), new HttpSessionIdChangedEvent(event.getSession(), oldSessionId)); + } + + private void extracted(HttpSession session, ApplicationEvent e) { Log log = LogFactory.getLog(LOGGER_NAME); - - if (log.isDebugEnabled()) { - log.debug("Publishing event: " + e); - } - - getContext(event.getSession().getServletContext()).publishEvent(e); + log.debug(LogMessage.format("Publishing event: %s", e)); + getContext(session.getServletContext()).publishEvent(e); } } diff --git a/web/src/main/java/org/springframework/security/web/session/SessionInformationExpiredEvent.java b/web/src/main/java/org/springframework/security/web/session/SessionInformationExpiredEvent.java index 7211a99e2a..c23c882d19 100644 --- a/web/src/main/java/org/springframework/security/web/session/SessionInformationExpiredEvent.java +++ b/web/src/main/java/org/springframework/security/web/session/SessionInformationExpiredEvent.java @@ -31,9 +31,9 @@ import org.springframework.util.Assert; */ public final class SessionInformationExpiredEvent extends ApplicationEvent { - private HttpServletRequest request; + private final HttpServletRequest request; - private HttpServletResponse response; + private final HttpServletResponse response; /** * Creates a new instance diff --git a/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java b/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java index 13ea888eb8..510aa072ac 100644 --- a/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java +++ b/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java @@ -25,6 +25,7 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; @@ -75,21 +76,20 @@ public class SessionManagementFilter extends GenericFilterBean { } @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws IOException, ServletException { if (request.getAttribute(FILTER_APPLIED) != null) { chain.doFilter(request, response); return; } - request.setAttribute(FILTER_APPLIED, Boolean.TRUE); - if (!this.securityContextRepository.containsContext(request)) { Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication != null && !this.trustResolver.isAnonymous(authentication)) { // The user has been authenticated during the current request, so call the // session strategy @@ -101,23 +101,19 @@ public class SessionManagementFilter extends GenericFilterBean { this.logger.debug("SessionAuthenticationStrategy rejected the authentication object", ex); SecurityContextHolder.clearContext(); this.failureHandler.onAuthenticationFailure(request, response, ex); - return; } // Eagerly save the security context to make it available for any possible - // re-entrant - // requests which may occur before the current request completes. - // SEC-1396. + // re-entrant requests which may occur before the current request + // completes. SEC-1396. this.securityContextRepository.saveContext(SecurityContextHolder.getContext(), request, response); } else { // No security context or authentication present. Check for a session // timeout if (request.getRequestedSessionId() != null && !request.isRequestedSessionIdValid()) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Requested session ID " + request.getRequestedSessionId() + " is invalid."); - } - + this.logger.debug( + LogMessage.format("Requested session ID %s is invalid.", request.getRequestedSessionId())); if (this.invalidSessionStrategy != null) { this.invalidSessionStrategy.onInvalidSessionDetected(request, response); return; @@ -125,7 +121,6 @@ public class SessionManagementFilter extends GenericFilterBean { } } } - chain.doFilter(request, response); } diff --git a/web/src/main/java/org/springframework/security/web/util/RedirectUrlBuilder.java b/web/src/main/java/org/springframework/security/web/util/RedirectUrlBuilder.java index f41e39f253..894b69e7a9 100644 --- a/web/src/main/java/org/springframework/security/web/util/RedirectUrlBuilder.java +++ b/web/src/main/java/org/springframework/security/web/util/RedirectUrlBuilder.java @@ -43,9 +43,7 @@ public class RedirectUrlBuilder { private String query; public void setScheme(String scheme) { - if (!("http".equals(scheme) | "https".equals(scheme))) { - throw new IllegalArgumentException("Unsupported scheme '" + scheme + "'"); - } + Assert.isTrue("http".equals(scheme) || "https".equals(scheme), () -> "Unsupported scheme '" + scheme + "'"); this.scheme = scheme; } @@ -75,33 +73,25 @@ public class RedirectUrlBuilder { public String getUrl() { StringBuilder sb = new StringBuilder(); - Assert.notNull(this.scheme, "scheme cannot be null"); Assert.notNull(this.serverName, "serverName cannot be null"); - sb.append(this.scheme).append("://").append(this.serverName); - // Append the port number if it's not standard for the scheme if (this.port != (this.scheme.equals("http") ? 80 : 443)) { sb.append(":").append(this.port); } - if (this.contextPath != null) { sb.append(this.contextPath); } - if (this.servletPath != null) { sb.append(this.servletPath); } - if (this.pathInfo != null) { sb.append(this.pathInfo); } - if (this.query != null) { sb.append("?").append(this.query); } - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/util/TextEscapeUtils.java b/web/src/main/java/org/springframework/security/web/util/TextEscapeUtils.java index abedceb714..71653664c6 100644 --- a/web/src/main/java/org/springframework/security/web/util/TextEscapeUtils.java +++ b/web/src/main/java/org/springframework/security/web/util/TextEscapeUtils.java @@ -28,57 +28,51 @@ public abstract class TextEscapeUtils { if (s == null || s.length() == 0) { return s; } - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < s.length(); i++) { - char c = s.charAt(i); - - if (c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9') { - sb.append(c); + char ch = s.charAt(i); + if (ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9') { + sb.append(ch); } - else if (c == '<') { + else if (ch == '<') { sb.append("<"); } - else if (c == '>') { + else if (ch == '>') { sb.append(">"); } - else if (c == '&') { + else if (ch == '&') { sb.append("&"); } - else if (Character.isWhitespace(c)) { - sb.append("&#").append((int) c).append(";"); + else if (Character.isWhitespace(ch)) { + sb.append("&#").append((int) ch).append(";"); } - else if (Character.isISOControl(c)) { + else if (Character.isISOControl(ch)) { // ignore control chars } - else if (Character.isHighSurrogate(c)) { + else if (Character.isHighSurrogate(ch)) { if (i + 1 >= s.length()) { // Unexpected end throw new IllegalArgumentException("Missing low surrogate character at end of string"); } char low = s.charAt(i + 1); - if (!Character.isLowSurrogate(low)) { throw new IllegalArgumentException( "Expected low surrogate character but found value = " + (int) low); } - - int codePoint = Character.toCodePoint(c, low); + int codePoint = Character.toCodePoint(ch, low); if (Character.isDefined(codePoint)) { sb.append("&#").append(codePoint).append(";"); } i++; // skip the next character as we have already dealt with it } - else if (Character.isLowSurrogate(c)) { - throw new IllegalArgumentException("Unexpected low surrogate character, value = " + (int) c); + else if (Character.isLowSurrogate(ch)) { + throw new IllegalArgumentException("Unexpected low surrogate character, value = " + (int) ch); } - else if (Character.isDefined(c)) { - sb.append("&#").append((int) c).append(";"); + else if (Character.isDefined(ch)) { + sb.append("&#").append((int) ch).append(";"); } // Ignore anything else } - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/util/ThrowableAnalyzer.java b/web/src/main/java/org/springframework/security/web/util/ThrowableAnalyzer.java index 73207cdcb1..2fac52493c 100755 --- a/web/src/main/java/org/springframework/security/web/util/ThrowableAnalyzer.java +++ b/web/src/main/java/org/springframework/security/web/util/ThrowableAnalyzer.java @@ -63,12 +63,10 @@ public class ThrowableAnalyzer { if (class1.isAssignableFrom(class2)) { return 1; } - else if (class2.isAssignableFrom(class1)) { + if (class2.isAssignableFrom(class1)) { return -1; } - else { - return class1.getName().compareTo(class2.getName()); - } + return class1.getName().compareTo(class2.getName()); }; /** @@ -82,7 +80,6 @@ public class ThrowableAnalyzer { */ public ThrowableAnalyzer() { this.extractorMap = new TreeMap<>(CLASS_HIERARCHY_COMPARATOR); - initExtractorMap(); } @@ -97,7 +94,6 @@ public class ThrowableAnalyzer { protected final void registerExtractor(Class throwableType, ThrowableCauseExtractor extractor) { Assert.notNull(extractor, "Invalid extractor: null"); - this.extractorMap.put(throwableType, extractor); } @@ -155,18 +151,13 @@ public class ThrowableAnalyzer { * @see #initExtractorMap() */ public final Throwable[] determineCauseChain(Throwable throwable) { - if (throwable == null) { - throw new IllegalArgumentException("Invalid throwable: null"); - } - + Assert.notNull(throwable, "Invalid throwable: null"); List chain = new ArrayList<>(); Throwable currentThrowable = throwable; - while (currentThrowable != null) { chain.add(currentThrowable); currentThrowable = extractCause(currentThrowable); } - return chain.toArray(new Throwable[0]); } @@ -183,7 +174,6 @@ public class ThrowableAnalyzer { return extractor.extractCause(throwable); } } - return null; } @@ -206,7 +196,6 @@ public class ThrowableAnalyzer { } } } - return null; } @@ -226,16 +215,10 @@ public class ThrowableAnalyzer { if (expectedBaseType == null) { return; } - - if (throwable == null) { - throw new IllegalArgumentException("Invalid throwable: null"); - } + Assert.notNull(throwable, "Invalid throwable: null"); Class throwableType = throwable.getClass(); - - if (!expectedBaseType.isAssignableFrom(throwableType)) { - throw new IllegalArgumentException("Invalid type: '" + throwableType.getName() - + "'. Has to be a subclass of '" + expectedBaseType.getName() + "'"); - } + Assert.isTrue(expectedBaseType.isAssignableFrom(throwableType), () -> "Invalid type: '" + + throwableType.getName() + "'. Has to be a subclass of '" + expectedBaseType.getName() + "'"); } } diff --git a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java index 8ff574f66e..98f4bbf5c8 100644 --- a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java +++ b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java @@ -30,6 +30,8 @@ import javax.servlet.http.HttpServletRequest; */ public final class UrlUtils { + private static final Pattern ABSOLUTE_URL = Pattern.compile("\\A[a-z0-9.+-]+://.*", Pattern.CASE_INSENSITIVE); + private UrlUtils() { } @@ -47,12 +49,9 @@ public final class UrlUtils { */ public static String buildFullRequestUrl(String scheme, String serverName, int serverPort, String requestURI, String queryString) { - scheme = scheme.toLowerCase(); - StringBuilder url = new StringBuilder(); url.append(scheme).append("://").append(serverName); - // Only add port if not default if ("http".equals(scheme)) { if (serverPort != 80) { @@ -64,15 +63,12 @@ public final class UrlUtils { url.append(":").append(serverPort); } } - // Use the requestURI as it is encoded (RFC 3986) and hence suitable for // redirects. url.append(requestURI); - if (queryString != null) { url.append("?").append(queryString); } - return url.toString(); } @@ -104,9 +100,7 @@ public final class UrlUtils { */ private static String buildRequestUrl(String servletPath, String requestURI, String contextPath, String pathInfo, String queryString) { - StringBuilder url = new StringBuilder(); - if (servletPath != null) { url.append(servletPath); if (pathInfo != null) { @@ -116,11 +110,9 @@ public final class UrlUtils { else { url.append(requestURI.substring(contextPath.length())); } - if (queryString != null) { url.append("?").append(queryString); } - return url.toString(); } @@ -136,12 +128,7 @@ public final class UrlUtils { * defined in RFC 1738. */ public static boolean isAbsoluteUrl(String url) { - if (url == null) { - return false; - } - final Pattern ABSOLUTE_URL = Pattern.compile("\\A[a-z0-9.+-]+://.*", Pattern.CASE_INSENSITIVE); - - return ABSOLUTE_URL.matcher(url).matches(); + return (url != null) ? ABSOLUTE_URL.matcher(url).matches() : false; } } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/AndRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/AndRequestMatcher.java index 567dd989a0..d3c9808677 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/AndRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/AndRequestMatcher.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; /** @@ -45,9 +46,7 @@ public final class AndRequestMatcher implements RequestMatcher { */ public AndRequestMatcher(List requestMatchers) { Assert.notEmpty(requestMatchers, "requestMatchers must contain a value"); - if (requestMatchers.contains(null)) { - throw new IllegalArgumentException("requestMatchers cannot contain null values"); - } + Assert.isTrue(!requestMatchers.contains(null), "requestMatchers cannot contain null values"); this.requestMatchers = requestMatchers; } @@ -62,9 +61,7 @@ public final class AndRequestMatcher implements RequestMatcher { @Override public boolean matches(HttpServletRequest request) { for (RequestMatcher matcher : this.requestMatchers) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (!matcher.matches(request)) { this.logger.debug("Did not match"); return false; diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/AntPathRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/AntPathRequestMatcher.java index 0e707cfdfe..4b63dcda1a 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/AntPathRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/AntPathRequestMatcher.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; @@ -116,7 +117,6 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria UrlPathHelper urlPathHelper) { Assert.hasText(pattern, "Pattern cannot be null or empty"); this.caseSensitive = caseSensitive; - if (pattern.equals(MATCH_ALL) || pattern.equals("**")) { pattern = MATCH_ALL; this.matcher = null; @@ -133,7 +133,6 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria this.matcher = new SpringAntMatcher(pattern, caseSensitive); } } - this.pattern = pattern; this.httpMethod = StringUtils.hasText(httpMethod) ? HttpMethod.valueOf(httpMethod) : null; this.urlPathHelper = urlPathHelper; @@ -149,28 +148,17 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria public boolean matches(HttpServletRequest request) { if (this.httpMethod != null && StringUtils.hasText(request.getMethod()) && this.httpMethod != valueOf(request.getMethod())) { - if (logger.isDebugEnabled()) { - logger.debug("Request '" + request.getMethod() + " " + getRequestPath(request) + "'" - + " doesn't match '" + this.httpMethod + " " + this.pattern + "'"); - } - + logger.debug(LogMessage.of(() -> "Request '" + request.getMethod() + " " + getRequestPath(request) + "'" + + " doesn't match '" + this.httpMethod + " " + this.pattern + "'")); return false; } - if (this.pattern.equals(MATCH_ALL)) { - if (logger.isDebugEnabled()) { - logger.debug("Request '" + getRequestPath(request) + "' matched by universal pattern '/**'"); - } - + logger.debug(LogMessage + .of(() -> "Request '" + getRequestPath(request) + "' matched by universal pattern '/**'")); return true; } - String url = getRequestPath(request); - - if (logger.isDebugEnabled()) { - logger.debug("Checking match of request : '" + url + "'; against '" + this.pattern + "'"); - } - + logger.debug(LogMessage.format("Checking match of request : '%s'; against '%s'", url, this.pattern)); return this.matcher.matches(url); } @@ -194,12 +182,10 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria return this.urlPathHelper.getPathWithinApplication(request); } String url = request.getServletPath(); - String pathInfo = request.getPathInfo(); if (pathInfo != null) { url = StringUtils.hasLength(url) ? url + pathInfo : pathInfo; } - return url; } @@ -212,7 +198,6 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria if (!(obj instanceof AntPathRequestMatcher)) { return false; } - AntPathRequestMatcher other = (AntPathRequestMatcher) obj; return this.pattern.equals(other.pattern) && this.httpMethod == other.httpMethod && this.caseSensitive == other.caseSensitive; @@ -230,13 +215,10 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria public String toString() { StringBuilder sb = new StringBuilder(); sb.append("Ant [pattern='").append(this.pattern).append("'"); - if (this.httpMethod != null) { sb.append(", ").append(this.httpMethod); } - sb.append("]"); - return sb.toString(); } @@ -251,9 +233,8 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria return HttpMethod.valueOf(method); } catch (IllegalArgumentException ex) { + return null; } - - return null; } private interface Matcher { @@ -306,7 +287,7 @@ public final class AntPathRequestMatcher implements RequestMatcher, RequestVaria private final boolean caseSensitive; private SubpathMatcher(String subpath, boolean caseSensitive) { - assert !subpath.contains("*"); + Assert.isTrue(!subpath.contains("*"), "subpath cannot contain \"*\""); this.subpath = caseSensitive ? subpath : subpath.toLowerCase(); this.length = subpath.length(); this.caseSensitive = caseSensitive; diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/AnyRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/AnyRequestMatcher.java index 7d83e3cd0a..52a818cc5d 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/AnyRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/AnyRequestMatcher.java @@ -28,6 +28,9 @@ public final class AnyRequestMatcher implements RequestMatcher { public static final RequestMatcher INSTANCE = new AnyRequestMatcher(); + private AnyRequestMatcher() { + } + @Override public boolean matches(HttpServletRequest request) { return true; @@ -50,7 +53,4 @@ public final class AnyRequestMatcher implements RequestMatcher { return "any request"; } - private AnyRequestMatcher() { - } - } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/ELRequestMatcherContext.java b/web/src/main/java/org/springframework/security/web/util/matcher/ELRequestMatcherContext.java index de339c3d3a..d6988b3284 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/ELRequestMatcherContext.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/ELRequestMatcherContext.java @@ -34,15 +34,7 @@ class ELRequestMatcherContext { public boolean hasHeader(String headerName, String value) { String header = this.request.getHeader(headerName); - if (!StringUtils.hasText(header)) { - return false; - } - - if (header.contains(value)) { - return true; - } - - return false; + return StringUtils.hasText(header) && header.contains(value); } } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java index c53686db2d..2666905c6a 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java @@ -47,7 +47,6 @@ public final class IpAddressMatcher implements RequestMatcher { * come. */ public IpAddressMatcher(String ipAddress) { - if (ipAddress.indexOf('/') > 0) { String[] addressAndMask = StringUtils.split(ipAddress, "/"); ipAddress = addressAndMask[0]; @@ -68,33 +67,24 @@ public final class IpAddressMatcher implements RequestMatcher { public boolean matches(String address) { InetAddress remoteAddress = parseAddress(address); - if (!this.requiredAddress.getClass().equals(remoteAddress.getClass())) { return false; } - if (this.nMaskBits < 0) { return remoteAddress.equals(this.requiredAddress); } - byte[] remAddr = remoteAddress.getAddress(); byte[] reqAddr = this.requiredAddress.getAddress(); - int nMaskFullBytes = this.nMaskBits / 8; byte finalByte = (byte) (0xFF00 >> (this.nMaskBits & 0x07)); - - // System.out.println("Mask is " + new sun.misc.HexDumpEncoder().encode(mask)); - for (int i = 0; i < nMaskFullBytes; i++) { if (remAddr[i] != reqAddr[i]) { return false; } } - if (finalByte != 0) { return (remAddr[nMaskFullBytes] & finalByte) == (reqAddr[nMaskFullBytes] & finalByte); } - return true; } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/MediaTypeRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/MediaTypeRequestMatcher.java index 29c22f81c6..7178bc3ed6 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/MediaTypeRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/MediaTypeRequestMatcher.java @@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.http.MediaType; import org.springframework.util.Assert; import org.springframework.web.HttpMediaTypeNotAcceptableException; @@ -206,28 +207,22 @@ public final class MediaTypeRequestMatcher implements RequestMatcher { this.logger.debug("Failed to parse MediaTypes, returning false", ex); return false; } - if (this.logger.isDebugEnabled()) { - this.logger.debug("httpRequestMediaTypes=" + httpRequestMediaTypes); - } + this.logger.debug(LogMessage.format("httpRequestMediaTypes=%s", httpRequestMediaTypes)); for (MediaType httpRequestMediaType : httpRequestMediaTypes) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Processing " + httpRequestMediaType); - } + this.logger.debug(LogMessage.format("Processing %s", httpRequestMediaType)); if (shouldIgnore(httpRequestMediaType)) { this.logger.debug("Ignoring"); continue; } if (this.useEquals) { boolean isEqualTo = this.matchingMediaTypes.contains(httpRequestMediaType); - this.logger.debug("isEqualTo " + isEqualTo); + this.logger.debug(LogMessage.format("isEqualTo %s", isEqualTo)); return isEqualTo; } for (MediaType matchingMediaType : this.matchingMediaTypes) { boolean isCompatibleWith = matchingMediaType.isCompatibleWith(httpRequestMediaType); - if (this.logger.isDebugEnabled()) { - this.logger.debug(matchingMediaType + " .isCompatibleWith " + httpRequestMediaType + " = " - + isCompatibleWith); - } + this.logger.debug(LogMessage.format("%s .isCompatibleWith %s = %s", matchingMediaType, + httpRequestMediaType, isCompatibleWith)); if (isCompatibleWith) { return true; } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/NegatedRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/NegatedRequestMatcher.java index f7e500a362..d31f6e6e36 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/NegatedRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/NegatedRequestMatcher.java @@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; /** @@ -50,9 +51,7 @@ public class NegatedRequestMatcher implements RequestMatcher { @Override public boolean matches(HttpServletRequest request) { boolean result = !this.requestMatcher.matches(request); - if (this.logger.isDebugEnabled()) { - this.logger.debug("matches = " + result); - } + this.logger.debug(LogMessage.format("matches = %s", result)); return result; } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/OrRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/OrRequestMatcher.java index cf501da94a..cacfa15af5 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/OrRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/OrRequestMatcher.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; /** @@ -45,9 +46,7 @@ public final class OrRequestMatcher implements RequestMatcher { */ public OrRequestMatcher(List requestMatchers) { Assert.notEmpty(requestMatchers, "requestMatchers must contain a value"); - if (requestMatchers.contains(null)) { - throw new IllegalArgumentException("requestMatchers cannot contain null values"); - } + Assert.isTrue(!requestMatchers.contains(null), "requestMatchers cannot contain null values"); this.requestMatchers = requestMatchers; } @@ -62,9 +61,7 @@ public final class OrRequestMatcher implements RequestMatcher { @Override public boolean matches(HttpServletRequest request) { for (RequestMatcher matcher : this.requestMatchers) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (matcher.matches(request)) { this.logger.debug("matched"); return true; diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/RegexRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/RegexRequestMatcher.java index 0e36a55a18..6d0bc19a1a 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/RegexRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/RegexRequestMatcher.java @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.util.StringUtils; @@ -42,6 +43,8 @@ import org.springframework.util.StringUtils; */ public final class RegexRequestMatcher implements RequestMatcher { + private static final int DEFAULT = 0; + private static final Log logger = LogFactory.getLog(RegexRequestMatcher.class); private final Pattern pattern; @@ -65,12 +68,7 @@ public final class RegexRequestMatcher implements RequestMatcher { * {@link Pattern#CASE_INSENSITIVE} flag set. */ public RegexRequestMatcher(String pattern, String httpMethod, boolean caseInsensitive) { - if (caseInsensitive) { - this.pattern = Pattern.compile(pattern, Pattern.CASE_INSENSITIVE); - } - else { - this.pattern = Pattern.compile(pattern); - } + this.pattern = Pattern.compile(pattern, caseInsensitive ? Pattern.CASE_INSENSITIVE : DEFAULT); this.httpMethod = StringUtils.hasText(httpMethod) ? HttpMethod.valueOf(httpMethod) : null; } @@ -86,28 +84,20 @@ public final class RegexRequestMatcher implements RequestMatcher { if (this.httpMethod != null && request.getMethod() != null && this.httpMethod != valueOf(request.getMethod())) { return false; } - String url = request.getServletPath(); String pathInfo = request.getPathInfo(); String query = request.getQueryString(); - if (pathInfo != null || query != null) { StringBuilder sb = new StringBuilder(url); - if (pathInfo != null) { sb.append(pathInfo); } - if (query != null) { sb.append('?').append(query); } url = sb.toString(); } - - if (logger.isDebugEnabled()) { - logger.debug("Checking match of request : '" + url + "'; against '" + this.pattern + "'"); - } - + logger.debug(LogMessage.format("Checking match of request : '%s'; against '%s'", url, this.pattern)); return this.pattern.matcher(url).matches(); } @@ -122,22 +112,18 @@ public final class RegexRequestMatcher implements RequestMatcher { return HttpMethod.valueOf(method); } catch (IllegalArgumentException ex) { + return null; } - - return null; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("Regex [pattern='").append(this.pattern).append("'"); - if (this.httpMethod != null) { sb.append(", ").append(this.httpMethod); } - sb.append("]"); - return sb.toString(); } diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/RequestHeaderRequestMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/RequestHeaderRequestMatcher.java index 8faa366d61..8b476f37bc 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/RequestHeaderRequestMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/RequestHeaderRequestMatcher.java @@ -87,7 +87,6 @@ public final class RequestHeaderRequestMatcher implements RequestMatcher { if (this.expectedHeaderValue == null) { return actualHeaderValue != null; } - return this.expectedHeaderValue.equals(actualHeaderValue); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImplTests.java b/web/src/test/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImplTests.java index 0835c51e80..6aaa1617cb 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImplTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/rememberme/JdbcTokenRepositoryImplTests.java @@ -29,6 +29,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -41,7 +42,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -133,14 +133,11 @@ public class JdbcTokenRepositoryImplTests { // SEC-1964 @Test public void retrievingTokenWithNoSeriesReturnsNull() { - given(this.logger.isDebugEnabled()).willReturn(true); - assertThat(this.repo.getTokenForSeries("missingSeries")).isNull(); - - verify(this.logger).isDebugEnabled(); - verify(this.logger).debug(eq("Querying token for series 'missingSeries' returned no results."), - any(EmptyResultDataAccessException.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + verify(this.logger).debug(captor.capture(), any(EmptyResultDataAccessException.class)); verifyNoMoreInteractions(this.logger); + assertThat(captor.getValue()).hasToString("Querying token for series 'missingSeries' returned no results."); } @Test diff --git a/web/src/test/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManagerTests.java b/web/src/test/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManagerTests.java index ec9056cd2f..c052bf2b2c 100644 --- a/web/src/test/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManagerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authorization/DelegatingReactiveAuthorizationManagerTests.java @@ -18,11 +18,12 @@ package org.springframework.security.web.server.authorization; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.MockitoAnnotations; import reactor.core.publisher.Mono; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager; import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.core.Authentication; @@ -40,7 +41,6 @@ import static org.mockito.Mockito.verifyZeroInteractions; * @author Rob Winch * @since 5.0 */ -@RunWith(MockitoJUnitRunner.class) public class DelegatingReactiveAuthorizationManagerTests { @Mock @@ -55,7 +55,6 @@ public class DelegatingReactiveAuthorizationManagerTests { @Mock AuthorityReactiveAuthorizationManager delegate2; - @Mock ServerWebExchange exchange; @Mock @@ -68,9 +67,12 @@ public class DelegatingReactiveAuthorizationManagerTests { @Before public void setup() { + MockitoAnnotations.initMocks(this); this.manager = DelegatingReactiveAuthorizationManager.builder() .add(new ServerWebExchangeMatcherEntry<>(this.match1, this.delegate1)) .add(new ServerWebExchangeMatcherEntry<>(this.match2, this.delegate2)).build(); + MockServerHttpRequest request = MockServerHttpRequest.get("/test").build(); + this.exchange = MockServerWebExchange.from(request); } @Test