Polish spring-security-web main code

Manually polish `spring-security-web` following the formatting
and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-08-03 22:57:18 -07:00 committed by Rob Winch
parent ef951bae90
commit 5bdd757108
178 changed files with 1676 additions and 2791 deletions

View File

@ -45,7 +45,6 @@ public interface AuthenticationEntryPoint {
* @param request that resulted in an <code>AuthenticationException</code>
* @param response so that the user agent can begin authentication
* @param authException that caused the invocation
*
*/
void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException)
throws IOException, ServletException;

View File

@ -24,7 +24,9 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
/**
* Simple implementation of <tt>RedirectStrategy</tt> which is the default used throughout
@ -51,11 +53,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException {
String redirectUrl = calculateRedirectUrl(request.getContextPath(), url);
redirectUrl = response.encodeRedirectURL(redirectUrl);
if (this.logger.isDebugEnabled()) {
this.logger.debug("Redirecting to '" + redirectUrl + "'");
}
this.logger.debug(LogMessage.format("Redirecting to '%s'", redirectUrl));
response.sendRedirect(redirectUrl);
}
@ -64,30 +62,20 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
if (isContextRelative()) {
return url;
}
else {
return contextPath + url;
}
return contextPath + url;
}
// Full URL, including http(s)://
if (!isContextRelative()) {
return url;
}
if (!url.contains(contextPath)) {
throw new IllegalArgumentException("The fully qualified URL does not include context path.");
}
Assert.isTrue(url.contains(contextPath), "The fully qualified URL does not include context path.");
// Calculate the relative URL from the fully qualified URL, minus the last
// occurrence of the scheme and base context.
url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme
url = url.substring(url.lastIndexOf("://") + 3);
url = url.substring(url.indexOf(contextPath) + contextPath.length());
if (url.length() > 1 && url.charAt(0) == '/') {
url = url.substring(1);
}
return url;
}

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.util.matcher.RequestMatcher;
/**
@ -47,7 +48,7 @@ public final class DefaultSecurityFilterChain implements SecurityFilterChain {
}
public DefaultSecurityFilterChain(RequestMatcher requestMatcher, List<Filter> filters) {
logger.info("Creating filter chain: " + requestMatcher + ", " + filters);
logger.info(LogMessage.format("Creating filter chain: %s, %s", requestMatcher, filters));
this.requestMatcher = requestMatcher;
this.filters = new ArrayList<>(filters);
}

View File

@ -32,6 +32,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.firewall.DefaultRequestRejectedHandler;
import org.springframework.security.web.firewall.FirewalledRequest;
@ -173,47 +174,37 @@ public class FilterChainProxy extends GenericFilterBean {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
boolean clearContext = request.getAttribute(FILTER_APPLIED) == null;
if (clearContext) {
try {
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
doFilterInternal(request, response, chain);
}
catch (RequestRejectedException ex) {
this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
}
finally {
SecurityContextHolder.clearContext();
request.removeAttribute(FILTER_APPLIED);
}
}
else {
if (!clearContext) {
doFilterInternal(request, response, chain);
return;
}
try {
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
doFilterInternal(request, response, chain);
}
catch (RequestRejectedException ex) {
this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
}
finally {
SecurityContextHolder.clearContext();
request.removeAttribute(FILTER_APPLIED);
}
}
private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
FirewalledRequest fwRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request);
HttpServletResponse fwResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response);
List<Filter> filters = getFilters(fwRequest);
FirewalledRequest firewallRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request);
HttpServletResponse firewallResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response);
List<Filter> filters = getFilters(firewallRequest);
if (filters == null || filters.size() == 0) {
if (logger.isDebugEnabled()) {
logger.debug(UrlUtils.buildRequestUrl(fwRequest)
+ ((filters != null) ? " has an empty filter list" : " has no matching filters"));
}
fwRequest.reset();
chain.doFilter(fwRequest, fwResponse);
logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(firewallRequest)
+ ((filters != null) ? " has an empty filter list" : " has no matching filters")));
firewallRequest.reset();
chain.doFilter(firewallRequest, firewallResponse);
return;
}
VirtualFilterChain vfc = new VirtualFilterChain(fwRequest, chain, filters);
vfc.doFilter(fwRequest, fwResponse);
VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters);
virtualFilterChain.doFilter(firewallRequest, firewallResponse);
}
/**
@ -227,7 +218,6 @@ public class FilterChainProxy extends GenericFilterBean {
return chain.getFilters();
}
}
return null;
}
@ -286,7 +276,6 @@ public class FilterChainProxy extends GenericFilterBean {
sb.append("Filter Chains: ");
sb.append(this.filterChains);
sb.append("]");
return sb.toString();
}
@ -317,30 +306,19 @@ public class FilterChainProxy extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (this.currentPosition == this.size) {
if (logger.isDebugEnabled()) {
logger.debug(UrlUtils.buildRequestUrl(this.firewalledRequest)
+ " reached end of additional filter chain; proceeding with original chain");
}
logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest)
+ " reached end of additional filter chain; proceeding with original chain"));
// Deactivate path stripping as we exit the security filter chain
this.firewalledRequest.reset();
this.originalChain.doFilter(request, response);
return;
}
else {
this.currentPosition++;
Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
if (logger.isDebugEnabled()) {
logger.debug(
UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position " + this.currentPosition
+ " of " + this.size + " in additional filter chain; firing Filter: '"
+ nextFilter.getClass().getSimpleName() + "'");
}
nextFilter.doFilter(request, response, this);
}
this.currentPosition++;
Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position "
+ this.currentPosition + " of " + this.size + " in additional filter chain; firing Filter: '"
+ nextFilter.getClass().getSimpleName() + "'"));
nextFilter.doFilter(request, response, this);
}
}

View File

@ -37,6 +37,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
/**
* Holds objects associated with a HTTP filter.
@ -65,10 +66,7 @@ public class FilterInvocation {
private HttpServletResponse response;
public FilterInvocation(ServletRequest request, ServletResponse response, FilterChain chain) {
if ((request == null) || (response == null) || (chain == null)) {
throw new IllegalArgumentException("Cannot pass null values to constructor");
}
Assert.isTrue(request != null && response != null && chain != null, "Cannot pass null values to constructor");
this.request = (HttpServletRequest) request;
this.response = (HttpServletResponse) response;
this.chain = chain;
@ -84,9 +82,7 @@ public class FilterInvocation {
public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) {
DummyRequest request = new DummyRequest();
if (contextPath == null) {
contextPath = "/cp";
}
contextPath = (contextPath != null) ? contextPath : "/cp";
request.setContextPath(contextPath);
request.setServletPath(servletPath);
request.setRequestURI(contextPath + servletPath + ((pathInfo != null) ? pathInfo : ""));
@ -256,9 +252,7 @@ public class FilterInvocation {
if (value == null) {
return -1;
}
else {
return Integer.parseInt(value);
}
return Integer.parseInt(value);
}
void addHeader(String name, String value) {
@ -267,8 +261,8 @@ public class FilterInvocation {
@Override
public String getParameter(String name) {
String[] arr = this.parameters.get(name);
return (arr != null && arr.length > 0) ? arr[0] : null;
String[] array = this.parameters.get(name);
return (array != null && array.length > 0) ? array[0] : null;
}
@Override
@ -317,7 +311,6 @@ public class FilterInvocation {
private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable {
Constructor<Lookup> constructor = Lookup.class.getDeclaredConstructor(Class.class);
constructor.setAccessible(true);
Class<?> clazz = method.getDeclaringClass();
return constructor.newInstance(clazz).in(clazz).unreflectSpecial(method, clazz).bindTo(proxy)
.invokeWithArguments(args);

View File

@ -56,7 +56,6 @@ public class PortMapperImpl implements PortMapper {
return httpPort;
}
}
return null;
}
@ -88,24 +87,19 @@ public class PortMapperImpl implements PortMapper {
*/
public void setPortMappings(Map<String, String> newMappings) {
Assert.notNull(newMappings, "A valid list of HTTPS port mappings must be provided");
this.httpsPortMappings.clear();
for (Map.Entry<String, String> entry : newMappings.entrySet()) {
Integer httpPort = Integer.valueOf(entry.getKey());
Integer httpsPort = Integer.valueOf(entry.getValue());
if ((httpPort < 1) || (httpPort > 65535) || (httpsPort < 1) || (httpsPort > 65535)) {
throw new IllegalArgumentException(
"one or both ports out of legal range: " + httpPort + ", " + httpsPort);
}
Assert.isTrue(isInPortRange(httpPort) && isInPortRange(httpsPort),
() -> "one or both ports out of legal range: " + httpPort + ", " + httpsPort);
this.httpsPortMappings.put(httpPort, httpsPort);
}
Assert.isTrue(!this.httpsPortMappings.isEmpty(), "must map at least one port");
}
if (this.httpsPortMappings.size() < 1) {
throw new IllegalArgumentException("must map at least one port");
}
private boolean isInPortRange(int port) {
return port >= 1 && port <= 65535;
}
}

View File

@ -45,24 +45,19 @@ public class PortResolverImpl implements PortResolver {
@Override
public int getServerPort(ServletRequest request) {
int serverPort = request.getServerPort();
Integer portLookup = null;
String scheme = request.getScheme().toLowerCase();
Integer mappedPort = getMappedPort(serverPort, scheme);
return (mappedPort != null) ? mappedPort : serverPort;
}
private Integer getMappedPort(int serverPort, String scheme) {
if ("http".equals(scheme)) {
portLookup = this.portMapper.lookupHttpPort(serverPort);
return this.portMapper.lookupHttpPort(serverPort);
}
else if ("https".equals(scheme)) {
portLookup = this.portMapper.lookupHttpsPort(serverPort);
if ("https".equals(scheme)) {
return this.portMapper.lookupHttpsPort(serverPort);
}
if (portLookup != null) {
// IE 6 bug
serverPort = portLookup;
}
return serverPort;
return null;
}
public void setPortMapper(PortMapper portMapper) {

View File

@ -18,7 +18,6 @@ package org.springframework.security.web.access;
import java.io.IOException;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -29,6 +28,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.WebAttributes;
import org.springframework.util.Assert;
/**
* Base implementation of {@link AccessDeniedHandler}.
@ -52,22 +52,19 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler {
@Override
public void handle(HttpServletRequest request, HttpServletResponse response,
AccessDeniedException accessDeniedException) throws IOException, ServletException {
if (!response.isCommitted()) {
if (this.errorPage != null) {
// Put exception into request scope (perhaps of use to a view)
request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException);
// Set the 403 status code.
response.setStatus(HttpStatus.FORBIDDEN.value());
// forward to error page.
RequestDispatcher dispatcher = request.getRequestDispatcher(this.errorPage);
dispatcher.forward(request, response);
}
else {
response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase());
}
if (response.isCommitted()) {
return;
}
if (this.errorPage == null) {
response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase());
return;
}
// Put exception into request scope (perhaps of use to a view)
request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException);
// Set the 403 status code.
response.setStatus(HttpStatus.FORBIDDEN.value());
// forward to error page.
request.getRequestDispatcher(this.errorPage).forward(request, response);
}
/**
@ -78,10 +75,7 @@ public class AccessDeniedHandlerImpl implements AccessDeniedHandler {
* limitations
*/
public void setErrorPage(String errorPage) {
if ((errorPage != null) && !errorPage.startsWith("/")) {
throw new IllegalArgumentException("errorPage must begin with '/'");
}
Assert.isTrue(errorPage == null || errorPage.startsWith("/"), "errorPage must begin with '/'");
this.errorPage = errorPage;
}

View File

@ -21,6 +21,7 @@ import java.util.Collection;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.access.intercept.AbstractSecurityInterceptor;
@ -47,7 +48,6 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
"AbstractSecurityInterceptor does not support FilterInvocations");
Assert.notNull(securityInterceptor.getAccessDecisionManager(),
"AbstractSecurityInterceptor must provide a non-null AccessDecisionManager");
this.securityInterceptor = securityInterceptor;
}
@ -82,34 +82,23 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
@Override
public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) {
Assert.notNull(uri, "uri parameter is required");
FilterInvocation fi = new FilterInvocation(contextPath, uri, method);
Collection<ConfigAttribute> attrs = this.securityInterceptor.obtainSecurityMetadataSource().getAttributes(fi);
if (attrs == null) {
if (this.securityInterceptor.isRejectPublicInvocations()) {
return false;
}
return true;
FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method);
Collection<ConfigAttribute> attributes = this.securityInterceptor.obtainSecurityMetadataSource()
.getAttributes(filterInvocation);
if (attributes == null) {
return (!this.securityInterceptor.isRejectPublicInvocations());
}
if (authentication == null) {
return false;
}
try {
this.securityInterceptor.getAccessDecisionManager().decide(authentication, fi, attrs);
this.securityInterceptor.getAccessDecisionManager().decide(authentication, filterInvocation, attributes);
return true;
}
catch (AccessDeniedException unauthorized) {
if (logger.isDebugEnabled()) {
logger.debug(fi.toString() + " denied for " + authentication.toString(), unauthorized);
}
catch (AccessDeniedException ex) {
logger.debug(LogMessage.format("%s denied for %s", filterInvocation, authentication), ex);
return false;
}
return true;
}
}

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
@ -107,14 +108,15 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
try {
chain.doFilter(request, response);
this.logger.debug("Chain processed normally");
}
catch (IOException ex) {
@ -123,38 +125,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
catch (Exception ex) {
// Try to extract a SpringSecurityException from the stacktrace
Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex);
RuntimeException ase = (AuthenticationException) this.throwableAnalyzer
RuntimeException securityException = (AuthenticationException) this.throwableAnalyzer
.getFirstThrowableOfType(AuthenticationException.class, causeChain);
if (ase == null) {
ase = (AccessDeniedException) this.throwableAnalyzer
if (securityException == null) {
securityException = (AccessDeniedException) this.throwableAnalyzer
.getFirstThrowableOfType(AccessDeniedException.class, causeChain);
}
if (ase != null) {
if (response.isCommitted()) {
throw new ServletException(
"Unable to handle the Spring Security Exception because the response is already committed.",
ex);
}
handleSpringSecurityException(request, response, chain, ase);
if (securityException == null) {
rethrow(ex);
}
else {
// Rethrow ServletExceptions and RuntimeExceptions as-is
if (ex instanceof ServletException) {
throw (ServletException) ex;
}
else if (ex instanceof RuntimeException) {
throw (RuntimeException) ex;
}
// Wrap other Exceptions. This shouldn't actually happen
// as we've already covered all the possibilities for doFilter
throw new RuntimeException(ex);
if (response.isCommitted()) {
throw new ServletException("Unable to handle the Spring Security Exception "
+ "because the response is already committed.", ex);
}
handleSpringSecurityException(request, response, chain, securityException);
}
}
private void rethrow(Exception ex) throws ServletException {
// Rethrow ServletExceptions and RuntimeExceptions as-is
if (ex instanceof ServletException) {
throw (ServletException) ex;
}
if (ex instanceof RuntimeException) {
throw (RuntimeException) ex;
}
// Wrap other Exceptions. This shouldn't actually happen
// as we've already covered all the possibilities for doFilter
throw new RuntimeException(ex);
}
public AuthenticationEntryPoint getAuthenticationEntryPoint() {
return this.authenticationEntryPoint;
}
@ -166,32 +166,36 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
private void handleSpringSecurityException(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, RuntimeException exception) throws IOException, ServletException {
if (exception instanceof AuthenticationException) {
this.logger.debug("Authentication exception occurred; redirecting to authentication entry point",
exception);
sendStartAuthentication(request, response, chain, (AuthenticationException) exception);
handleAuthenticationException(request, response, chain, (AuthenticationException) exception);
}
else if (exception instanceof AccessDeniedException) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (this.authenticationTrustResolver.isAnonymous(authentication)
|| this.authenticationTrustResolver.isRememberMe(authentication)) {
this.logger.debug(
"Access is denied (user is " + (this.authenticationTrustResolver.isAnonymous(authentication)
? "anonymous" : "not fully authenticated")
+ "); redirecting to authentication entry point",
exception);
handleAccessDeniedException(request, response, chain, (AccessDeniedException) exception);
}
}
sendStartAuthentication(request, response, chain,
new InsufficientAuthenticationException(
this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication",
"Full authentication is required to access this resource")));
}
else {
this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler",
exception);
private void handleAuthenticationException(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, AuthenticationException exception) throws ServletException, IOException {
this.logger.debug("Authentication exception occurred; redirecting to authentication entry point", exception);
sendStartAuthentication(request, response, chain, exception);
}
this.accessDeniedHandler.handle(request, response, (AccessDeniedException) exception);
}
private void handleAccessDeniedException(HttpServletRequest request, HttpServletResponse response,
FilterChain chain, AccessDeniedException exception) throws ServletException, IOException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
boolean isAnonymous = this.authenticationTrustResolver.isAnonymous(authentication);
if (isAnonymous || this.authenticationTrustResolver.isRememberMe(authentication)) {
this.logger.debug(LogMessage
.of(() -> "Access is denied (user is " + (isAnonymous ? "anonymous" : "not fully authenticated")
+ "); redirecting to authentication entry point"),
exception);
sendStartAuthentication(request, response, chain,
new InsufficientAuthenticationException(
this.messages.getMessage("ExceptionTranslationFilter.insufficientAuthentication",
"Full authentication is required to access this resource")));
}
else {
this.logger.debug("Access is denied (user is not anonymous); delegating to AccessDeniedHandler", exception);
this.accessDeniedHandler.handle(request, response, exception);
}
}
@ -232,7 +236,6 @@ public class ExceptionTranslationFilter extends GenericFilterBean {
@Override
protected void initExtractorMap() {
super.initExtractorMap();
registerExtractor(ServletException.class, (throwable) -> {
ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class);
return ((ServletException) throwable).getRootCause();

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.PortMapper;
import org.springframework.security.web.PortMapperImpl;
@ -43,10 +44,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
private PortResolver portResolver = new PortResolverImpl();
/** The scheme ("http://" or "https://") */
/**
* The scheme ("http://" or "https://")
*/
private final String scheme;
/** The standard port for the scheme (80 for http, 443 for https) */
/**
* The standard port for the scheme (80 for http, 443 for https)
*/
private final int standardPort;
private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
@ -60,21 +65,14 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
public void commence(HttpServletRequest request, HttpServletResponse response) throws IOException {
String queryString = request.getQueryString();
String redirectUrl = request.getRequestURI() + ((queryString != null) ? ("?" + queryString) : "");
Integer currentPort = this.portResolver.getServerPort(request);
Integer redirectPort = getMappedPort(currentPort);
if (redirectPort != null) {
boolean includePort = redirectPort != this.standardPort;
redirectUrl = this.scheme + request.getServerName() + ((includePort) ? (":" + redirectPort) : "")
+ redirectUrl;
String port = (includePort) ? (":" + redirectPort) : "";
redirectUrl = this.scheme + request.getServerName() + port + redirectUrl;
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Redirecting to: " + redirectUrl);
}
this.logger.debug(LogMessage.format("Redirecting to: %s", redirectUrl));
this.redirectStrategy.sendRedirect(request, response, redirectUrl);
}

View File

@ -64,10 +64,8 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
return;
}
}
for (ChannelProcessor processor : this.channelProcessors) {
processor.decide(invocation, config);
if (invocation.getResponse().isCommitted()) {
break;
}
@ -79,11 +77,10 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
}
@SuppressWarnings("cast")
public void setChannelProcessors(List<?> newList) {
Assert.notEmpty(newList, "A list of ChannelProcessors is required");
this.channelProcessors = new ArrayList<>(newList.size());
for (Object currentObject : newList) {
public void setChannelProcessors(List<?> channelProcessors) {
Assert.notEmpty(channelProcessors, "A list of ChannelProcessors is required");
this.channelProcessors = new ArrayList<>(channelProcessors.size());
for (Object currentObject : channelProcessors) {
Assert.isInstanceOf(ChannelProcessor.class, currentObject, () -> "ChannelProcessor "
+ currentObject.getClass().getName() + " must implement ChannelProcessor");
this.channelProcessors.add((ChannelProcessor) currentObject);
@ -95,13 +92,11 @@ public class ChannelDecisionManagerImpl implements ChannelDecisionManager, Initi
if (ANY_CHANNEL.equals(attribute.getAttribute())) {
return true;
}
for (ChannelProcessor processor : this.channelProcessors) {
if (processor.supports(attribute)) {
return true;
}
}
return false;
}

View File

@ -28,6 +28,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.web.FilterInvocation;
import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource;
@ -93,35 +94,26 @@ public class ChannelProcessingFilter extends GenericFilterBean {
public void afterPropertiesSet() {
Assert.notNull(this.securityMetadataSource, "securityMetadataSource must be specified");
Assert.notNull(this.channelDecisionManager, "channelDecisionManager must be specified");
Collection<ConfigAttribute> attrDefs = this.securityMetadataSource.getAllConfigAttributes();
if (attrDefs == null) {
if (this.logger.isWarnEnabled()) {
this.logger.warn(
"Could not validate configuration attributes as the FilterInvocationSecurityMetadataSource did "
+ "not return any attributes");
}
Collection<ConfigAttribute> attributes = this.securityMetadataSource.getAllConfigAttributes();
if (attributes == null) {
this.logger.warn("Could not validate configuration attributes as the "
+ "FilterInvocationSecurityMetadataSource did not return any attributes");
return;
}
Set<ConfigAttribute> unsupportedAttributes = getUnsupportedAttributes(attributes);
Assert.isTrue(unsupportedAttributes.isEmpty(),
() -> "Unsupported configuration attributes: " + unsupportedAttributes);
this.logger.info("Validated configuration attributes");
}
private Set<ConfigAttribute> getUnsupportedAttributes(Collection<ConfigAttribute> attrDefs) {
Set<ConfigAttribute> unsupportedAttributes = new HashSet<>();
for (ConfigAttribute attr : attrDefs) {
if (!this.channelDecisionManager.supports(attr)) {
unsupportedAttributes.add(attr);
}
}
if (unsupportedAttributes.size() == 0) {
if (this.logger.isInfoEnabled()) {
this.logger.info("Validated configuration attributes");
}
}
else {
throw new IllegalArgumentException("Unsupported configuration attributes: " + unsupportedAttributes);
}
return unsupportedAttributes;
}
@Override
@ -129,22 +121,15 @@ public class ChannelProcessingFilter extends GenericFilterBean {
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
FilterInvocation fi = new FilterInvocation(request, response, chain);
Collection<ConfigAttribute> attr = this.securityMetadataSource.getAttributes(fi);
if (attr != null) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Request: " + fi.toString() + "; ConfigAttributes: " + attr);
}
this.channelDecisionManager.decide(fi, attr);
if (fi.getResponse().isCommitted()) {
FilterInvocation filterInvocation = new FilterInvocation(request, response, chain);
Collection<ConfigAttribute> attributes = this.securityMetadataSource.getAttributes(filterInvocation);
if (attributes != null) {
this.logger.debug(LogMessage.format("Request: %s; ConfigAttributes: %s", filterInvocation, attributes));
this.channelDecisionManager.decide(filterInvocation, attributes);
if (filterInvocation.getResponse().isCommitted()) {
return;
}
}
chain.doFilter(request, response);
}

View File

@ -40,7 +40,6 @@ public interface ChannelProcessor {
/**
* Decided whether the presented {@link FilterInvocation} provides the appropriate
* level of channel security based on the requested list of <tt>ConfigAttribute</tt>s.
*
*/
void decide(FilterInvocation invocation, Collection<ConfigAttribute> config) throws IOException, ServletException;

View File

@ -55,10 +55,7 @@ public class InsecureChannelProcessor implements InitializingBean, ChannelProces
@Override
public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config)
throws IOException, ServletException {
if ((invocation == null) || (config == null)) {
throw new IllegalArgumentException("Nulls cannot be provided");
}
Assert.isTrue(invocation != null && config != null, "Nulls cannot be provided");
for (ConfigAttribute attribute : config) {
if (supports(attribute)) {
if (invocation.getHttpRequest().isSecure()) {

View File

@ -56,7 +56,6 @@ public class SecureChannelProcessor implements InitializingBean, ChannelProcesso
public void decide(FilterInvocation invocation, Collection<ConfigAttribute> config)
throws IOException, ServletException {
Assert.isTrue((invocation != null) && (config != null), "Nulls cannot be provided");
for (ConfigAttribute attribute : config) {
if (supports(attribute)) {
if (!invocation.getHttpRequest().isSecure()) {

View File

@ -41,25 +41,37 @@ abstract class AbstractVariableEvaluationContextPostProcessor
@Override
public final EvaluationContext postProcess(EvaluationContext context, FilterInvocation invocation) {
final HttpServletRequest request = invocation.getHttpRequest();
return new DelegatingEvaluationContext(context) {
private Map<String, String> variables;
@Override
public Object lookupVariable(String name) {
Object result = super.lookupVariable(name);
if (result != null) {
return result;
}
if (this.variables == null) {
this.variables = extractVariables(request);
}
return this.variables.get(name);
}
};
return new VariableEvaluationContext(context, invocation.getHttpRequest());
}
abstract Map<String, String> extractVariables(HttpServletRequest request);
/**
* {@link DelegatingEvaluationContext} to expose variable.
*/
class VariableEvaluationContext extends DelegatingEvaluationContext {
private final HttpServletRequest request;
private Map<String, String> variables;
VariableEvaluationContext(EvaluationContext delegate, HttpServletRequest request) {
super(delegate);
this.request = request;
}
@Override
public Object lookupVariable(String name) {
Object result = super.lookupVariable(name);
if (result != null) {
return result;
}
if (this.variables == null) {
this.variables = extractVariables(this.request);
}
return this.variables.get(name);
}
}
}

View File

@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import javax.servlet.http.HttpServletRequest;
@ -58,29 +59,29 @@ public final class ExpressionBasedFilterInvocationSecurityMetadataSource
private static LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> processMap(
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap, ExpressionParser parser) {
Assert.notNull(parser, "SecurityExpressionHandler returned a null parser object");
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> processed = new LinkedHashMap<>(requestMap);
requestMap.forEach((request, value) -> process(parser, request, value, processed::put));
return processed;
}
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestToExpressionAttributesMap = new LinkedHashMap<>(
requestMap);
for (Map.Entry<RequestMatcher, Collection<ConfigAttribute>> entry : requestMap.entrySet()) {
RequestMatcher request = entry.getKey();
Assert.isTrue(entry.getValue().size() == 1, () -> "Expected a single expression attribute for " + request);
ArrayList<ConfigAttribute> attributes = new ArrayList<>(1);
String expression = entry.getValue().toArray(new ConfigAttribute[1])[0].getAttribute();
logger.debug("Adding web access control expression '" + expression + "', for " + request);
AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request);
try {
attributes.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor));
}
catch (ParseException ex) {
throw new IllegalArgumentException("Failed to parse expression '" + expression + "'");
}
requestToExpressionAttributesMap.put(request, attributes);
private static void process(ExpressionParser parser, RequestMatcher request, Collection<ConfigAttribute> value,
BiConsumer<RequestMatcher, Collection<ConfigAttribute>> consumer) {
String expression = getExpression(request, value);
logger.debug("Adding web access control expression '" + expression + "', for " + request);
AbstractVariableEvaluationContextPostProcessor postProcessor = createPostProcessor(request);
ArrayList<ConfigAttribute> processed = new ArrayList<>(1);
try {
processed.add(new WebExpressionConfigAttribute(parser.parseExpression(expression), postProcessor));
}
catch (ParseException ex) {
throw new IllegalArgumentException("Failed to parse expression '" + expression + "'");
}
consumer.accept(request, processed);
}
return requestToExpressionAttributesMap;
private static String getExpression(RequestMatcher request, Collection<ConfigAttribute> value) {
Assert.isTrue(value.size() == 1, () -> "Expected a single expression attribute for " + request);
return value.toArray(new ConfigAttribute[1])[0].getAttribute();
}
private static AbstractVariableEvaluationContextPostProcessor createPostProcessor(RequestMatcher request) {

View File

@ -25,6 +25,7 @@ import org.springframework.security.access.expression.ExpressionUtils;
import org.springframework.security.access.expression.SecurityExpressionHandler;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.FilterInvocation;
import org.springframework.util.Assert;
/**
* Voter which handles web authorisation decisions.
@ -37,21 +38,19 @@ public class WebExpressionVoter implements AccessDecisionVoter<FilterInvocation>
private SecurityExpressionHandler<FilterInvocation> expressionHandler = new DefaultWebSecurityExpressionHandler();
@Override
public int vote(Authentication authentication, FilterInvocation fi, Collection<ConfigAttribute> attributes) {
assert authentication != null;
assert fi != null;
assert attributes != null;
WebExpressionConfigAttribute weca = findConfigAttribute(attributes);
if (weca == null) {
public int vote(Authentication authentication, FilterInvocation filterInvocation,
Collection<ConfigAttribute> attributes) {
Assert.notNull(authentication, "authentication must not be null");
Assert.notNull(filterInvocation, "filterInvocation must not be null");
Assert.notNull(attributes, "attributes must not be null");
WebExpressionConfigAttribute webExpressionConfigAttribute = findConfigAttribute(attributes);
if (webExpressionConfigAttribute == null) {
return ACCESS_ABSTAIN;
}
EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, fi);
ctx = weca.postProcess(ctx, fi);
return ExpressionUtils.evaluateAsBoolean(weca.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
EvaluationContext ctx = webExpressionConfigAttribute.postProcess(
this.expressionHandler.createEvaluationContext(authentication, filterInvocation), filterInvocation);
return ExpressionUtils.evaluateAsBoolean(webExpressionConfigAttribute.getAuthorizeExpression(), ctx)
? ACCESS_GRANTED : ACCESS_DENIED;
}
private WebExpressionConfigAttribute findConfigAttribute(Collection<ConfigAttribute> attributes) {

View File

@ -29,13 +29,13 @@ import org.springframework.security.web.util.matcher.IpAddressMatcher;
*/
public class WebSecurityExpressionRoot extends SecurityExpressionRoot {
// private FilterInvocation filterInvocation;
/** Allows direct access to the request object */
/**
* Allows direct access to the request object
*/
public final HttpServletRequest request;
public WebSecurityExpressionRoot(Authentication a, FilterInvocation fi) {
super(a);
// this.filterInvocation = fi;
this.request = fi.getRequest();
}
@ -47,7 +47,8 @@ public class WebSecurityExpressionRoot extends SecurityExpressionRoot {
* @return true if the IP address of the current request is in the required range.
*/
public boolean hasIpAddress(String ipAddress) {
return (new IpAddressMatcher(ipAddress).matches(this.request));
IpAddressMatcher matcher = new IpAddressMatcher(ipAddress);
return matcher.matches(this.request);
}
}

View File

@ -65,18 +65,13 @@ public class DefaultFilterInvocationSecurityMetadataSource implements FilterInvo
*/
public DefaultFilterInvocationSecurityMetadataSource(
LinkedHashMap<RequestMatcher, Collection<ConfigAttribute>> requestMap) {
this.requestMap = requestMap;
}
@Override
public Collection<ConfigAttribute> getAllConfigAttributes() {
Set<ConfigAttribute> allAttributes = new HashSet<>();
for (Map.Entry<RequestMatcher, Collection<ConfigAttribute>> entry : this.requestMap.entrySet()) {
allAttributes.addAll(entry.getValue());
}
this.requestMap.values().forEach(allAttributes::addAll);
return allAttributes;
}

View File

@ -78,8 +78,7 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
FilterInvocation fi = new FilterInvocation(request, response, chain);
invoke(fi);
invoke(new FilterInvocation(request, response, chain));
}
public FilterInvocationSecurityMetadataSource getSecurityMetadataSource() {
@ -100,30 +99,30 @@ public class FilterSecurityInterceptor extends AbstractSecurityInterceptor imple
return FilterInvocation.class;
}
public void invoke(FilterInvocation fi) throws IOException, ServletException {
if ((fi.getRequest() != null) && (fi.getRequest().getAttribute(FILTER_APPLIED) != null)
&& this.observeOncePerRequest) {
public void invoke(FilterInvocation filterInvocation) throws IOException, ServletException {
if (isApplied(filterInvocation) && this.observeOncePerRequest) {
// filter already applied to this request and user wants us to observe
// once-per-request handling, so don't re-do security checking
fi.getChain().doFilter(fi.getRequest(), fi.getResponse());
filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse());
return;
}
else {
// first time this request being called, so perform security checking
if (fi.getRequest() != null && this.observeOncePerRequest) {
fi.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE);
}
InterceptorStatusToken token = super.beforeInvocation(fi);
try {
fi.getChain().doFilter(fi.getRequest(), fi.getResponse());
}
finally {
super.finallyInvocation(token);
}
super.afterInvocation(token, null);
// first time this request being called, so perform security checking
if (filterInvocation.getRequest() != null && this.observeOncePerRequest) {
filterInvocation.getRequest().setAttribute(FILTER_APPLIED, Boolean.TRUE);
}
InterceptorStatusToken token = super.beforeInvocation(filterInvocation);
try {
filterInvocation.getChain().doFilter(filterInvocation.getRequest(), filterInvocation.getResponse());
}
finally {
super.finallyInvocation(token);
}
super.afterInvocation(token, null);
}
private boolean isApplied(FilterInvocation filterInvocation) {
return (filterInvocation.getRequest() != null)
&& (filterInvocation.getRequest().getAttribute(FILTER_APPLIED) != null);
}
/**

View File

@ -77,7 +77,6 @@ public class RequestKey {
}
sb.append(this.url);
sb.append("]");
return sb.toString();
}

View File

@ -30,6 +30,7 @@ import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
@ -206,52 +207,39 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
* </ol>
*/
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (!requiresAuthentication(request, response)) {
chain.doFilter(request, response);
return;
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Request is to process authentication");
}
Authentication authResult;
this.logger.debug("Request is to process authentication");
try {
authResult = attemptAuthentication(request, response);
if (authResult == null) {
Authentication authenticationResult = attemptAuthentication(request, response);
if (authenticationResult == null) {
// return immediately as subclass has indicated that it hasn't completed
// authentication
return;
}
this.sessionStrategy.onAuthentication(authResult, request, response);
this.sessionStrategy.onAuthentication(authenticationResult, request, response);
// Authentication success
if (this.continueChainBeforeSuccessfulAuthentication) {
chain.doFilter(request, response);
}
successfulAuthentication(request, response, chain, authenticationResult);
}
catch (InternalAuthenticationServiceException failed) {
this.logger.error("An internal error occurred while trying to authenticate the user.", failed);
unsuccessfulAuthentication(request, response, failed);
return;
}
catch (AuthenticationException failed) {
catch (AuthenticationException ex) {
// Authentication failed
unsuccessfulAuthentication(request, response, failed);
return;
unsuccessfulAuthentication(request, response, ex);
}
// Authentication success
if (this.continueChainBeforeSuccessfulAuthentication) {
chain.doFilter(request, response);
}
successfulAuthentication(request, response, chain, authResult);
}
/**
@ -316,20 +304,13 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
*/
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
Authentication authResult) throws IOException, ServletException {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Authentication success. Updating SecurityContextHolder to contain: " + authResult);
}
this.logger.debug(
LogMessage.format("Authentication success. Updating SecurityContextHolder to contain: %s", authResult));
SecurityContextHolder.getContext().setAuthentication(authResult);
this.rememberMeServices.loginSuccess(request, response, authResult);
// Fire event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
}
this.successHandler.onAuthenticationSuccess(request, response, authResult);
}
@ -347,15 +328,12 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext();
if (this.logger.isDebugEnabled()) {
this.logger.debug("Authentication request failed: " + failed.toString(), failed);
this.logger.debug("Updated SecurityContextHolder to contain null Authentication");
this.logger.debug("Delegating to authentication failure handler " + this.failureHandler);
}
this.rememberMeServices.loginFail(request, response);
this.failureHandler.onAuthenticationFailure(request, response, failed);
}

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
@ -84,18 +85,16 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler {
protected void handle(HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
String targetUrl = determineTargetUrl(request, response, authentication);
if (response.isCommitted()) {
this.logger.debug("Response has already been committed. Unable to redirect to " + targetUrl);
this.logger.debug(
LogMessage.format("Response has already been committed. Unable to redirect to %s", targetUrl));
return;
}
this.redirectStrategy.sendRedirect(request, response, targetUrl);
}
/**
* Builds the target URL according to the logic defined in the main class Javadoc
*
* @since 5.2
*/
protected String determineTargetUrl(HttpServletRequest request, HttpServletResponse response,
@ -110,30 +109,23 @@ public abstract class AbstractAuthenticationTargetUrlRequestHandler {
if (isAlwaysUseDefaultTargetUrl()) {
return this.defaultTargetUrl;
}
// Check for the parameter and use that if available
String targetUrl = null;
if (this.targetUrlParameter != null) {
targetUrl = request.getParameter(this.targetUrlParameter);
if (StringUtils.hasText(targetUrl)) {
this.logger.debug("Found targetUrlParameter in request: " + targetUrl);
return targetUrl;
}
}
if (this.useReferer && !StringUtils.hasLength(targetUrl)) {
targetUrl = request.getHeader("Referer");
this.logger.debug("Using Referer header: " + targetUrl);
}
if (!StringUtils.hasText(targetUrl)) {
targetUrl = this.defaultTargetUrl;
this.logger.debug("Using default Url: " + targetUrl);
}
return targetUrl;
}

View File

@ -26,6 +26,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.Authentication;
@ -85,31 +86,24 @@ public class AnonymousAuthenticationFilter extends GenericFilterBean implements
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
throws IOException, ServletException {
if (SecurityContextHolder.getContext().getAuthentication() == null) {
SecurityContextHolder.getContext().setAuthentication(createAuthentication((HttpServletRequest) req));
if (this.logger.isDebugEnabled()) {
this.logger.debug("Populated SecurityContextHolder with anonymous token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
}
this.logger.debug(LogMessage.of(() -> "Populated SecurityContextHolder with anonymous token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
}
else {
if (this.logger.isDebugEnabled()) {
this.logger.debug("SecurityContextHolder not populated with anonymous token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
}
this.logger.debug(LogMessage
.of(() -> "SecurityContextHolder not populated with anonymous token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
}
chain.doFilter(req, res);
}
protected Authentication createAuthentication(HttpServletRequest request) {
AnonymousAuthenticationToken auth = new AnonymousAuthenticationToken(this.key, this.principal,
AnonymousAuthenticationToken token = new AnonymousAuthenticationToken(this.key, this.principal,
this.authorities);
auth.setDetails(this.authenticationDetailsSource.buildDetails(request));
return auth;
token.setDetails(this.authenticationDetailsSource.buildDetails(request));
return token;
}
public void setAuthenticationDetailsSource(

View File

@ -29,7 +29,7 @@ import org.springframework.util.Assert;
/**
* Adapts a {@link AuthenticationEntryPoint} into a {@link AuthenticationFailureHandler}
*
* @author sbespalov
* @author Sergey Bespalov
* @since 5.2.0
*/
public class AuthenticationEntryPointFailureHandler implements AuthenticationFailureHandler {

View File

@ -84,7 +84,6 @@ public class AuthenticationFilter extends OncePerRequestFilter {
AuthenticationConverter authenticationConverter) {
Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null");
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
this.authenticationManagerResolver = authenticationManagerResolver;
this.authenticationConverter = authenticationConverter;
}
@ -142,19 +141,16 @@ public class AuthenticationFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
return;
}
try {
Authentication authenticationResult = attemptAuthentication(request, response);
if (authenticationResult == null) {
filterChain.doFilter(request, response);
return;
}
HttpSession session = request.getSession(false);
if (session != null) {
request.changeSessionId();
}
successfulAuthentication(request, response, filterChain, authenticationResult);
}
catch (AuthenticationException ex) {
@ -182,13 +178,11 @@ public class AuthenticationFilter extends OncePerRequestFilter {
if (authentication == null) {
return null;
}
AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request);
Authentication authenticationResult = authenticationManager.authenticate(authentication);
if (authenticationResult == null) {
throw new ServletException("AuthenticationManager should not return null Authentication object.");
}
return authenticationResult;
}

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.util.matcher.ELRequestMatcher;
@ -62,7 +63,7 @@ import org.springframework.util.Assert;
*/
public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPoint, InitializingBean {
private final Log logger = LogFactory.getLog(getClass());
private static final Log logger = LogFactory.getLog(DelegatingAuthenticationEntryPoint.class);
private final LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints;
@ -75,25 +76,16 @@ public class DelegatingAuthenticationEntryPoint implements AuthenticationEntryPo
@Override
public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException, ServletException {
for (RequestMatcher requestMatcher : this.entryPoints.keySet()) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Trying to match using " + requestMatcher);
}
logger.debug(LogMessage.format("Trying to match using %s", requestMatcher));
if (requestMatcher.matches(request)) {
AuthenticationEntryPoint entryPoint = this.entryPoints.get(requestMatcher);
if (this.logger.isDebugEnabled()) {
this.logger.debug("Match found! Executing " + entryPoint);
}
logger.debug(LogMessage.format("Match found! Executing %s", entryPoint));
entryPoint.commence(request, response, authException);
return;
}
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("No match found. Using default entry point " + this.defaultEntryPoint);
}
logger.debug(LogMessage.format("No match found. Using default entry point %s", this.defaultEntryPoint));
// No EntryPoint matched, use defaultEntryPoint
this.defaultEntryPoint.commence(request, response, authException);
}

View File

@ -62,9 +62,6 @@ public class DelegatingAuthenticationFailureHandler implements AuthenticationFai
this.defaultHandler = defaultHandler;
}
/**
* {@inheritDoc}
*/
@Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException {

View File

@ -49,7 +49,6 @@ public class ExceptionMappingAuthenticationFailureHandler extends SimpleUrlAuthe
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException {
String url = this.failureUrlMap.get(exception.getClass().getName());
if (url != null) {
getRedirectStrategy().sendRedirect(request, response, url);
}

View File

@ -55,9 +55,7 @@ public class Http403ForbiddenEntryPoint implements AuthenticationEntryPoint {
@Override
public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException arg2)
throws IOException {
if (logger.isDebugEnabled()) {
logger.debug("Pre-authenticated entry point called. Rejecting access");
}
logger.debug("Pre-authenticated entry point called. Rejecting access");
response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied");
}

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.DefaultRedirectStrategy;
@ -93,9 +94,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
public void afterPropertiesSet() {
Assert.isTrue(StringUtils.hasText(this.loginFormUrl) && UrlUtils.isValidRedirectUrl(this.loginFormUrl),
"loginFormUrl must be specified and must be a valid redirect URL");
if (this.useForward && UrlUtils.isAbsoluteUrl(this.loginFormUrl)) {
throw new IllegalArgumentException("useForward must be false if using an absolute loginFormURL");
}
Assert.isTrue(!this.useForward || !UrlUtils.isAbsoluteUrl(this.loginFormUrl),
"useForward must be false if using an absolute loginFormURL");
Assert.notNull(this.portMapper, "portMapper must be specified");
Assert.notNull(this.portResolver, "portResolver must be specified");
}
@ -110,7 +110,6 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
*/
protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) {
return getLoginFormUrl();
}
@ -120,75 +119,55 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
@Override
public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException, ServletException {
String redirectUrl = null;
if (this.useForward) {
if (this.forceHttps && "http".equals(request.getScheme())) {
// First redirect the current request to HTTPS.
// When that request is received, the forward to the login page will be
// used.
redirectUrl = buildHttpsRedirectUrlForRequest(request);
}
if (redirectUrl == null) {
String loginForm = determineUrlToUseForThisRequest(request, response, authException);
if (logger.isDebugEnabled()) {
logger.debug("Server side forward to: " + loginForm);
}
RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm);
dispatcher.forward(request, response);
return;
}
}
else {
if (!this.useForward) {
// redirect to login page. Use https if forceHttps true
redirectUrl = buildRedirectUrlToLoginPage(request, response, authException);
String redirectUrl = buildRedirectUrlToLoginPage(request, response, authException);
this.redirectStrategy.sendRedirect(request, response, redirectUrl);
return;
}
this.redirectStrategy.sendRedirect(request, response, redirectUrl);
String redirectUrl = null;
if (this.forceHttps && "http".equals(request.getScheme())) {
// First redirect the current request to HTTPS. When that request is received,
// the forward to the login page will be used.
redirectUrl = buildHttpsRedirectUrlForRequest(request);
}
if (redirectUrl != null) {
this.redirectStrategy.sendRedirect(request, response, redirectUrl);
return;
}
String loginForm = determineUrlToUseForThisRequest(request, response, authException);
logger.debug(LogMessage.format("Server side forward to: %s", loginForm));
RequestDispatcher dispatcher = request.getRequestDispatcher(loginForm);
dispatcher.forward(request, response);
return;
}
protected String buildRedirectUrlToLoginPage(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) {
String loginForm = determineUrlToUseForThisRequest(request, response, authException);
if (UrlUtils.isAbsoluteUrl(loginForm)) {
return loginForm;
}
int serverPort = this.portResolver.getServerPort(request);
String scheme = request.getScheme();
RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
urlBuilder.setScheme(scheme);
urlBuilder.setServerName(request.getServerName());
urlBuilder.setPort(serverPort);
urlBuilder.setContextPath(request.getContextPath());
urlBuilder.setPathInfo(loginForm);
if (this.forceHttps && "http".equals(scheme)) {
Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort);
if (httpsPort != null) {
// Overwrite scheme and port in the redirect URL
urlBuilder.setScheme("https");
urlBuilder.setPort(httpsPort);
}
else {
logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort);
logger.warn(LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s",
serverPort));
}
}
return urlBuilder.getUrl();
}
@ -197,10 +176,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
* current request to HTTPS, before doing a forward to the login page.
*/
protected String buildHttpsRedirectUrlForRequest(HttpServletRequest request) throws IOException, ServletException {
int serverPort = this.portResolver.getServerPort(request);
Integer httpsPort = this.portMapper.lookupHttpsPort(serverPort);
if (httpsPort != null) {
RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
urlBuilder.setScheme("https");
@ -210,13 +187,11 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
urlBuilder.setServletPath(request.getServletPath());
urlBuilder.setPathInfo(request.getPathInfo());
urlBuilder.setQuery(request.getQueryString());
return urlBuilder.getUrl();
}
// Fall through to server-side forward with warning message
logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort);
logger.warn(
LogMessage.format("Unable to redirect to HTTPS as no port mapping found for HTTP port %s", serverPort));
return null;
}

View File

@ -74,10 +74,8 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws ServletException, IOException {
SavedRequest savedRequest = this.requestCache.getRequest(request, response);
if (savedRequest == null) {
super.onAuthenticationSuccess(request, response, authentication);
return;
}
String targetUrlParameter = getTargetUrlParameter();
@ -85,12 +83,9 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
|| (targetUrlParameter != null && StringUtils.hasText(request.getParameter(targetUrlParameter)))) {
this.requestCache.removeRequest(request, response);
super.onAuthenticationSuccess(request, response, authentication);
return;
}
clearAuthenticationAttributes(request);
// Use the DefaultSavedRequest URL
String targetUrl = savedRequest.getRedirectUrl();
this.logger.debug("Redirecting to DefaultSavedRequest Url: " + targetUrl);

View File

@ -76,24 +76,19 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail
@Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException, ServletException {
if (this.defaultFailureUrl == null) {
this.logger.debug("No failure URL set, sending 401 Unauthorized error");
response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
return;
}
saveException(request, exception);
if (this.forwardToDestination) {
this.logger.debug("Forwarding to " + this.defaultFailureUrl);
request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response);
}
else {
saveException(request, exception);
if (this.forwardToDestination) {
this.logger.debug("Forwarding to " + this.defaultFailureUrl);
request.getRequestDispatcher(this.defaultFailureUrl).forward(request, response);
}
else {
this.logger.debug("Redirecting to " + this.defaultFailureUrl);
this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl);
}
this.logger.debug("Redirecting to " + this.defaultFailureUrl);
this.redirectStrategy.sendRedirect(request, response, this.defaultFailureUrl);
}
}
@ -108,13 +103,11 @@ public class SimpleUrlAuthenticationFailureHandler implements AuthenticationFail
protected final void saveException(HttpServletRequest request, AuthenticationException exception) {
if (this.forwardToDestination) {
request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
return;
}
else {
HttpSession session = request.getSession(false);
if (session != null || this.allowSessionCreation) {
request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
}
HttpSession session = request.getSession(false);
if (session != null || this.allowSessionCreation) {
request.getSession().setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, exception);
}
}

View File

@ -59,7 +59,6 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio
@Override
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException, ServletException {
handle(request, response, authentication);
clearAuthenticationAttributes(request);
}
@ -70,12 +69,9 @@ public class SimpleUrlAuthenticationSuccessHandler extends AbstractAuthenticatio
*/
protected final void clearAuthenticationAttributes(HttpServletRequest request) {
HttpSession session = request.getSession(false);
if (session == null) {
return;
if (session != null) {
session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
}
session.removeAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
}
}

View File

@ -74,25 +74,14 @@ public class UsernamePasswordAuthenticationFilter extends AbstractAuthentication
if (this.postOnly && !request.getMethod().equals("POST")) {
throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod());
}
String username = obtainUsername(request);
String password = obtainPassword(request);
if (username == null) {
username = "";
}
if (password == null) {
password = "";
}
username = (username != null) ? username : "";
username = username.trim();
String password = obtainPassword(request);
password = (password != null) ? password : "";
UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken(username, password);
// Allow subclasses to set the "details" property
setDetails(request, authRequest);
return this.getAuthenticationManager().authenticate(authRequest);
}

View File

@ -44,7 +44,6 @@ public class WebAuthenticationDetails implements Serializable {
*/
public WebAuthenticationDetails(HttpServletRequest request) {
this.remoteAddress = request.getRemoteAddr();
HttpSession session = request.getSession(false);
this.sessionId = (session != null) ? session.getId() : null;
}
@ -62,39 +61,31 @@ public class WebAuthenticationDetails implements Serializable {
@Override
public boolean equals(Object obj) {
if (obj instanceof WebAuthenticationDetails) {
WebAuthenticationDetails rhs = (WebAuthenticationDetails) obj;
if ((this.remoteAddress == null) && (rhs.getRemoteAddress() != null)) {
WebAuthenticationDetails other = (WebAuthenticationDetails) obj;
if ((this.remoteAddress == null) && (other.getRemoteAddress() != null)) {
return false;
}
if ((this.remoteAddress != null) && (rhs.getRemoteAddress() == null)) {
if ((this.remoteAddress != null) && (other.getRemoteAddress() == null)) {
return false;
}
if (this.remoteAddress != null) {
if (!this.remoteAddress.equals(rhs.getRemoteAddress())) {
if (!this.remoteAddress.equals(other.getRemoteAddress())) {
return false;
}
}
if ((this.sessionId == null) && (rhs.getSessionId() != null)) {
if ((this.sessionId == null) && (other.getSessionId() != null)) {
return false;
}
if ((this.sessionId != null) && (rhs.getSessionId() == null)) {
if ((this.sessionId != null) && (other.getSessionId() == null)) {
return false;
}
if (this.sessionId != null) {
if (!this.sessionId.equals(rhs.getSessionId())) {
if (!this.sessionId.equals(other.getSessionId())) {
return false;
}
}
return true;
}
return false;
}
@ -118,15 +109,12 @@ public class WebAuthenticationDetails implements Serializable {
@Override
public int hashCode() {
int code = 7654;
if (this.remoteAddress != null) {
code = code * (this.remoteAddress.hashCode() % 7);
}
if (this.sessionId != null) {
code = code * (this.sessionId.hashCode() % 7);
}
return code;
}
@ -136,7 +124,6 @@ public class WebAuthenticationDetails implements Serializable {
sb.append(super.toString()).append(": ");
sb.append("RemoteIpAddress: ").append(this.getRemoteAddress()).append("; ");
sb.append("SessionId: ").append(this.getSessionId());
return sb.toString();
}

View File

@ -43,15 +43,14 @@ public final class CookieClearingLogoutHandler implements LogoutHandler {
Assert.notNull(cookiesToClear, "List of cookies cannot be null");
List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>();
for (String cookieName : cookiesToClear) {
Function<HttpServletRequest, Cookie> f = (request) -> {
cookieList.add((request) -> {
Cookie cookie = new Cookie(cookieName, null);
String cookiePath = request.getContextPath() + "/";
cookie.setPath(cookiePath);
cookie.setMaxAge(0);
cookie.setSecure(request.isSecure());
return cookie;
};
cookieList.add(f);
});
}
this.cookiesToClear = cookieList;
}
@ -65,8 +64,7 @@ public final class CookieClearingLogoutHandler implements LogoutHandler {
List<Function<HttpServletRequest, Cookie>> cookieList = new ArrayList<>();
for (Cookie cookie : cookiesToClear) {
Assert.isTrue(cookie.getMaxAge() == 0, "Cookie maxAge must be 0");
Function<HttpServletRequest, Cookie> f = (request) -> cookie;
cookieList.add(f);
cookieList.add((request) -> cookie);
}
this.cookiesToClear = cookieList;
}

View File

@ -25,6 +25,7 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.util.UrlUtils;
@ -83,25 +84,20 @@ public class LogoutFilter extends GenericFilterBean {
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (requiresLogout(request, response)) {
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
if (this.logger.isDebugEnabled()) {
this.logger.debug("Logging out user '" + auth + "' and transferring to logout destination");
}
this.logger.debug(LogMessage.format("Logging out user '%s' and transferring to logout destination", auth));
this.handler.logout(request, response, auth);
this.logoutSuccessHandler.onLogoutSuccess(request, response, auth);
return;
}
chain.doFilter(request, response);
}

View File

@ -23,6 +23,7 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
@ -61,16 +62,14 @@ public class SecurityContextLogoutHandler implements LogoutHandler {
if (this.invalidateHttpSession) {
HttpSession session = request.getSession(false);
if (session != null) {
this.logger.debug("Invalidating session: " + session.getId());
this.logger.debug(LogMessage.format("Invalidating session: %s", session.getId()));
session.invalidate();
}
}
if (this.clearAuthentication) {
SecurityContext context = SecurityContextHolder.getContext();
context.setAuthentication(null);
}
SecurityContextHolder.clearContext();
}

View File

@ -28,6 +28,7 @@ import javax.servlet.http.HttpSession;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent;
@ -124,16 +125,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (this.logger.isDebugEnabled()) {
this.logger
.debug("Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication());
}
this.logger.debug(LogMessage
.of(() -> "Checking secure context token: " + SecurityContextHolder.getContext().getAuthentication()));
if (this.requiresAuthenticationRequestMatcher.matches((HttpServletRequest) request)) {
doAuthenticate((HttpServletRequest) request, (HttpServletResponse) response);
}
chain.doFilter(request, response);
}
@ -156,21 +152,15 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
* @return true if the principal has changed, else false
*/
protected boolean principalChanged(HttpServletRequest request, Authentication currentAuthentication) {
Object principal = getPreAuthenticatedPrincipal(request);
if ((principal instanceof String) && currentAuthentication.getName().equals(principal)) {
return false;
}
if (principal != null && principal.equals(currentAuthentication.getPrincipal())) {
return false;
}
if (this.logger.isDebugEnabled()) {
this.logger
.debug("Pre-authenticated principal has changed to " + principal + " and will be reauthenticated");
}
this.logger.debug(LogMessage.format("Pre-authenticated principal has changed to %s and will be reauthenticated",
principal));
return true;
}
@ -179,35 +169,24 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
*/
private void doAuthenticate(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException {
Authentication authResult;
Object principal = getPreAuthenticatedPrincipal(request);
Object credentials = getPreAuthenticatedCredentials(request);
if (principal == null) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("No pre-authenticated principal found in request");
}
this.logger.debug("No pre-authenticated principal found in request");
return;
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("preAuthenticatedPrincipal = " + principal + ", trying to authenticate");
}
this.logger.debug(LogMessage.format("preAuthenticatedPrincipal = %s, trying to authenticate", principal));
Object credentials = getPreAuthenticatedCredentials(request);
try {
PreAuthenticatedAuthenticationToken authRequest = new PreAuthenticatedAuthenticationToken(principal,
credentials);
authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
authResult = this.authenticationManager.authenticate(authRequest);
successfulAuthentication(request, response, authResult);
PreAuthenticatedAuthenticationToken authenticationRequest = new PreAuthenticatedAuthenticationToken(
principal, credentials);
authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
successfulAuthentication(request, response, authenticationResult);
}
catch (AuthenticationException failed) {
unsuccessfulAuthentication(request, response, failed);
catch (AuthenticationException ex) {
unsuccessfulAuthentication(request, response, ex);
if (!this.continueFilterChainOnUnsuccessfulAuthentication) {
throw failed;
throw ex;
}
}
}
@ -218,15 +197,11 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
*/
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response,
Authentication authResult) throws IOException, ServletException {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Authentication success: " + authResult);
}
this.logger.debug(LogMessage.format("Authentication success: %s", authResult));
SecurityContextHolder.getContext().setAuthentication(authResult);
// Fire event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
}
if (this.authenticationSuccessHandler != null) {
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult);
}
@ -241,12 +216,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) throws IOException, ServletException {
SecurityContextHolder.clearContext();
if (this.logger.isDebugEnabled()) {
this.logger.debug("Cleared security context due to exception", failed);
}
this.logger.debug("Cleared security context due to exception", failed);
request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed);
if (this.authenticationFailureHandler != null) {
this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
}
@ -355,36 +326,27 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
@Override
public boolean matches(HttpServletRequest request) {
Authentication currentUser = SecurityContextHolder.getContext().getAuthentication();
if (currentUser == null) {
return true;
}
if (!AbstractPreAuthenticatedProcessingFilter.this.checkForPrincipalChanges) {
return false;
}
if (!principalChanged(request, currentUser)) {
return false;
}
AbstractPreAuthenticatedProcessingFilter.this.logger
.debug("Pre-authenticated principal has changed and will be reauthenticated");
if (AbstractPreAuthenticatedProcessingFilter.this.invalidateSessionOnPrincipalChange) {
SecurityContextHolder.clearContext();
HttpSession session = request.getSession(false);
if (session != null) {
AbstractPreAuthenticatedProcessingFilter.this.logger.debug("Invalidating existing session");
session.invalidate();
request.getSession();
}
}
return true;
}

View File

@ -21,6 +21,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.Ordered;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.BadCredentialsException;
@ -50,11 +51,11 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro
private static final Log logger = LogFactory.getLog(PreAuthenticatedAuthenticationProvider.class);
private AuthenticationUserDetailsService<PreAuthenticatedAuthenticationToken> preAuthenticatedUserDetailsService = null;
private AuthenticationUserDetailsService<PreAuthenticatedAuthenticationToken> preAuthenticatedUserDetailsService;
private UserDetailsChecker userDetailsChecker = new AccountStatusUserDetailsChecker();
private boolean throwExceptionWhenTokenRejected = false;
private boolean throwExceptionWhenTokenRejected;
private int order = -1; // default: same as non-ordered
@ -77,38 +78,27 @@ public class PreAuthenticatedAuthenticationProvider implements AuthenticationPro
if (!supports(authentication.getClass())) {
return null;
}
if (logger.isDebugEnabled()) {
logger.debug("PreAuthenticated authentication request: " + authentication);
}
logger.debug(LogMessage.format("PreAuthenticated authentication request: %s", authentication));
if (authentication.getPrincipal() == null) {
logger.debug("No pre-authenticated principal found in request.");
if (this.throwExceptionWhenTokenRejected) {
throw new BadCredentialsException("No pre-authenticated principal found in request.");
}
return null;
}
if (authentication.getCredentials() == null) {
logger.debug("No pre-authenticated credentials found in request.");
if (this.throwExceptionWhenTokenRejected) {
throw new BadCredentialsException("No pre-authenticated credentials found in request.");
}
return null;
}
UserDetails ud = this.preAuthenticatedUserDetailsService
UserDetails userDetails = this.preAuthenticatedUserDetailsService
.loadUserDetails((PreAuthenticatedAuthenticationToken) authentication);
this.userDetailsChecker.check(ud);
PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(ud,
authentication.getCredentials(), ud.getAuthorities());
this.userDetailsChecker.check(userDetails);
PreAuthenticatedAuthenticationToken result = new PreAuthenticatedAuthenticationToken(userDetails,
authentication.getCredentials(), userDetails.getAuthorities());
result.setDetails(authentication.getDetails());
return result;
}

View File

@ -46,7 +46,6 @@ public class PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails extends
public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(HttpServletRequest request,
Collection<? extends GrantedAuthority> authorities) {
super(request);
List<GrantedAuthority> temp = new ArrayList<>(authorities.size());
temp.addAll(authorities);
this.authorities = Collections.unmodifiableList(temp);

View File

@ -59,12 +59,10 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat
@Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
String principal = (String) request.getAttribute(this.principalEnvironmentVariable);
if (principal == null && this.exceptionIfVariableMissing) {
throw new PreAuthenticatedCredentialsNotFoundException(
this.principalEnvironmentVariable + " variable not found in request.");
}
return principal;
}
@ -78,7 +76,6 @@ public class RequestAttributeAuthenticationFilter extends AbstractPreAuthenticat
if (this.credentialsEnvironmentVariable != null) {
return request.getAttribute(this.credentialsEnvironmentVariable);
}
return "N/A";
}

View File

@ -60,12 +60,10 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP
@Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
String principal = request.getHeader(this.principalRequestHeader);
if (principal == null && this.exceptionIfHeaderMissing) {
throw new PreAuthenticatedCredentialsNotFoundException(
this.principalRequestHeader + " header not found in request.");
}
return principal;
}
@ -79,7 +77,6 @@ public class RequestHeaderAuthenticationFilter extends AbstractPreAuthenticatedP
if (this.credentialsRequestHeader != null) {
return request.getHeader(this.credentialsRequestHeader);
}
return "N/A";
}

View File

@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper;
@ -76,13 +77,11 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements
*/
protected Collection<String> getUserRoles(HttpServletRequest request) {
ArrayList<String> j2eeUserRolesList = new ArrayList<>();
for (String role : this.j2eeMappableRoles) {
if (request.isUserInRole(role)) {
j2eeUserRolesList.add(role);
}
}
return j2eeUserRolesList;
}
@ -93,19 +92,14 @@ public class J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource implements
*/
@Override
public PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails buildDetails(HttpServletRequest context) {
Collection<String> j2eeUserRoles = getUserRoles(context);
Collection<? extends GrantedAuthority> userGas = this.j2eeUserRoles2GrantedAuthoritiesMapper
Collection<? extends GrantedAuthority> userGrantedAuthorities = this.j2eeUserRoles2GrantedAuthoritiesMapper
.getGrantedAuthorities(j2eeUserRoles);
if (this.logger.isDebugEnabled()) {
this.logger.debug("J2EE roles [" + j2eeUserRoles + "] mapped to Granted Authorities: [" + userGas + "]");
this.logger.debug(LogMessage.format("J2EE roles [%s] mapped to Granted Authorities: [%s]", j2eeUserRoles,
userGrantedAuthorities));
}
PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails result = new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(
context, userGas);
return result;
return new PreAuthenticatedGrantedAuthoritiesWebAuthenticationDetails(context, userGrantedAuthorities);
}
/**

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.j2ee;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/**
@ -36,9 +37,7 @@ public class J2eePreAuthenticatedProcessingFilter extends AbstractPreAuthenticat
@Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) {
Object principal = (httpRequest.getUserPrincipal() != null) ? httpRequest.getUserPrincipal().getName() : null;
if (this.logger.isDebugEnabled()) {
this.logger.debug("PreAuthenticated J2EE principal: " + principal);
}
this.logger.debug(LogMessage.format("PreAuthenticated J2EE principal: %s", principal));
return principal;
}

View File

@ -22,6 +22,7 @@ import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.xml.parsers.DocumentBuilder;
@ -43,6 +44,7 @@ import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.security.core.authority.mapping.MappableAttributesRetriever;
import org.springframework.util.Assert;
/**
* This <tt>MappableAttributesRetriever</tt> implementation reads the list of defined J2EE
@ -82,17 +84,17 @@ public class WebXmlMappableAttributesRetriever
Resource webXml = this.resourceLoader.getResource("/WEB-INF/web.xml");
Document doc = getDocument(webXml.getInputStream());
NodeList webApp = doc.getElementsByTagName("web-app");
if (webApp.getLength() != 1) {
throw new IllegalArgumentException("Failed to find 'web-app' element in resource" + webXml);
}
Assert.isTrue(webApp.getLength() == 1, () -> "Failed to find 'web-app' element in resource" + webXml);
NodeList securityRoles = ((Element) webApp.item(0)).getElementsByTagName("security-role");
List<String> roleNames = getRoleNames(webXml, securityRoles);
this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames));
}
private List<String> getRoleNames(Resource webXml, NodeList securityRoles) {
ArrayList<String> roleNames = new ArrayList<>();
for (int i = 0; i < securityRoles.getLength(); i++) {
Element secRoleElt = (Element) securityRoles.item(i);
NodeList roles = secRoleElt.getElementsByTagName("role-name");
Element securityRoleElement = (Element) securityRoles.item(i);
NodeList roles = securityRoleElement.getElementsByTagName("role-name");
if (roles.getLength() > 0) {
String roleName = roles.item(0).getTextContent().trim();
roleNames.add(roleName);
@ -102,22 +104,19 @@ public class WebXmlMappableAttributesRetriever
this.logger.info("No security-role elements found in " + webXml);
}
}
this.mappableAttributes = Collections.unmodifiableSet(new HashSet<>(roleNames));
return roleNames;
}
/**
* @return Document for the specified InputStream
*/
private Document getDocument(InputStream aStream) {
Document doc;
try {
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
factory.setValidating(false);
DocumentBuilder db = factory.newDocumentBuilder();
db.setEntityResolver(new MyEntityResolver());
doc = db.parse(aStream);
return doc;
DocumentBuilder builder = factory.newDocumentBuilder();
builder.setEntityResolver(new MyEntityResolver());
return builder.parse(aStream);
}
catch (FactoryConfigurationError | IOException | SAXException | ParserConfigurationException ex) {
throw new RuntimeException("Unable to parse document object", ex);

View File

@ -31,6 +31,8 @@ import javax.security.auth.Subject;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
/**
* WebSphere Security helper class to allow retrieval of the current username and groups.
* <p>
@ -75,9 +77,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
* @return String the security name for the given subject
*/
private static String getSecurityName(final Subject subject) {
if (logger.isDebugEnabled()) {
logger.debug("Determining Websphere security name for subject " + subject);
}
logger.debug(LogMessage.format("Determining Websphere security name for subject %s", subject));
String userSecurityName = null;
if (subject != null) {
// SEC-803
@ -86,9 +86,7 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
userSecurityName = (String) invokeMethod(getSecurityNameMethod(), credential);
}
}
if (logger.isDebugEnabled()) {
logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject);
}
logger.debug(LogMessage.format("Websphere security name is %s for subject %s", subject, userSecurityName));
return userSecurityName;
}
@ -119,69 +117,56 @@ final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroups
*/
@SuppressWarnings("unchecked")
private static List<String> getWebSphereGroups(final String securityName) {
Context ic = null;
Context context = null;
try {
// TODO: Cache UserRegistry object
ic = new InitialContext();
Object objRef = ic.lookup(USER_REGISTRY);
context = new InitialContext();
Object objRef = context.lookup(USER_REGISTRY);
Object userReg = invokeMethod(getNarrowMethod(), null, objRef,
Class.forName("com.ibm.websphere.security.UserRegistry"));
if (logger.isDebugEnabled()) {
logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry "
+ userReg);
}
final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg,
logger.debug(LogMessage.format("Determining WebSphere groups for user %s using WebSphere UserRegistry %s",
securityName, userReg));
final Collection<String> groups = (Collection<String>) invokeMethod(getGroupsForUserMethod(), userReg,
new Object[] { securityName });
if (logger.isDebugEnabled()) {
logger.debug("Groups for user " + securityName + ": " + groups.toString());
}
return new ArrayList(groups);
logger.debug(LogMessage.format("Groups for user %s: %s", securityName, groups));
return new ArrayList<String>(groups);
}
catch (Exception ex) {
logger.error("Exception occured while looking up groups for user", ex);
throw new RuntimeException("Exception occured while looking up groups for user", ex);
}
finally {
try {
if (ic != null) {
ic.close();
}
}
catch (NamingException ex) {
logger.debug("Exception occured while closing context", ex);
closeContext(context);
}
}
private static void closeContext(Context context) {
try {
if (context != null) {
context.close();
}
}
catch (NamingException ex) {
logger.debug("Exception occured while closing context", ex);
}
}
private static Object invokeMethod(Method method, Object instance, Object... args) {
try {
return method.invoke(instance, args);
}
catch (IllegalArgumentException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex);
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "."
+ method.getName() + "(" + Arrays.asList(args) + ")", ex);
}
catch (IllegalAccessException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex);
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "."
+ method.getName() + "(" + Arrays.asList(args) + ")", ex);
}
catch (InvocationTargetException ex) {
logger.error("Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")", ex);
throw new RuntimeException("Error while invoking method " + method.getClass().getName() + "."
+ method.getName() + "(" + Arrays.asList(args) + ")", ex);
catch (IllegalArgumentException | IllegalAccessException | InvocationTargetException ex) {
String message = "Error while invoking method " + method.getClass().getName() + "." + method.getName() + "("
+ Arrays.asList(args) + ")";
logger.error(message, ex);
throw new RuntimeException(message, ex);
}
}
private static Method getMethod(String className, String methodName, String[] parameterTypeNames) {
try {
Class<?> c = Class.forName(className);
final int len = parameterTypeNames.length;
int len = parameterTypeNames.length;
Class<?>[] parameterTypes = new Class[len];
for (int i = 0; i < len; i++) {
parameterTypes[i] = Class.forName(parameterTypeNames[i]);

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.authentication.preauth.websphere;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/**
@ -51,9 +52,7 @@ public class WebSpherePreAuthenticatedProcessingFilter extends AbstractPreAuthen
@Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) {
Object principal = this.wasHelper.getCurrentUserName();
if (this.logger.isDebugEnabled()) {
this.logger.debug("PreAuthenticated WebSphere principal: " + principal);
}
this.logger.debug(LogMessage.format("PreAuthenticated WebSphere principal: %s", principal));
return principal;
}

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.Attributes2GrantedAuthoritiesMapper;
@ -68,9 +69,8 @@ public class WebSpherePreAuthenticatedWebAuthenticationDetailsSource implements
List<String> webSphereGroups = this.wasHelper.getGroupsForCurrentUser();
Collection<? extends GrantedAuthority> userGas = this.webSphereGroups2GrantedAuthoritiesMapper
.getGrantedAuthorities(webSphereGroups);
if (this.logger.isDebugEnabled()) {
this.logger.debug("WebSphere groups: " + webSphereGroups + " mapped to Granted Authorities: " + userGas);
}
this.logger.debug(
LogMessage.format("WebSphere groups: %s mapped to Granted Authorities: %s", webSphereGroups, userGas));
return userGas;
}

View File

@ -25,6 +25,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.context.MessageSource;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.SpringSecurityMessageSource;
import org.springframework.util.Assert;
@ -58,24 +59,15 @@ public class SubjectDnX509PrincipalExtractor implements X509PrincipalExtractor {
public Object extractPrincipal(X509Certificate clientCert) {
// String subjectDN = clientCert.getSubjectX500Principal().getName();
String subjectDN = clientCert.getSubjectDN().getName();
this.logger.debug("Subject DN is '" + subjectDN + "'");
this.logger.debug(LogMessage.format("Subject DN is '%s'", subjectDN));
Matcher matcher = this.subjectDnPattern.matcher(subjectDN);
if (!matcher.find()) {
throw new BadCredentialsException(this.messages.getMessage("SubjectDnX509PrincipalExtractor.noMatching",
new Object[] { subjectDN }, "No matching pattern was found in subject DN: {0}"));
}
if (matcher.groupCount() != 1) {
throw new IllegalArgumentException("Regular expression must contain a single group ");
}
Assert.isTrue(matcher.groupCount() == 1, "Regular expression must contain a single group ");
String username = matcher.group(1);
this.logger.debug("Extracted Principal name is '" + username + "'");
this.logger.debug(LogMessage.format("Extracted Principal name is '%s'", username));
return username;
}

View File

@ -20,6 +20,7 @@ import java.security.cert.X509Certificate;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
/**
@ -32,12 +33,7 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing
@Override
protected Object getPreAuthenticatedPrincipal(HttpServletRequest request) {
X509Certificate cert = extractClientCertificate(request);
if (cert == null) {
return null;
}
return this.principalExtractor.extractPrincipal(cert);
return (cert != null) ? this.principalExtractor.extractPrincipal(cert) : null;
}
@Override
@ -47,19 +43,11 @@ public class X509AuthenticationFilter extends AbstractPreAuthenticatedProcessing
private X509Certificate extractClientCertificate(HttpServletRequest request) {
X509Certificate[] certs = (X509Certificate[]) request.getAttribute("javax.servlet.request.X509Certificate");
if (certs != null && certs.length > 0) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("X.509 client authentication certificate:" + certs[0]);
}
this.logger.debug(LogMessage.format("X.509 client authentication certificate:%s", certs[0]));
return certs[0];
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("No client certificate found in request.");
}
this.logger.debug("No client certificate found in request.");
return null;
}

View File

@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountStatusException;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationDetailsSource;
@ -118,47 +119,38 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
@Override
public final Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) {
String rememberMeCookie = extractRememberMeCookie(request);
if (rememberMeCookie == null) {
return null;
}
this.logger.debug("Remember-me cookie detected");
if (rememberMeCookie.length() == 0) {
this.logger.debug("Cookie was empty");
cancelCookie(request, response);
return null;
}
UserDetails user = null;
try {
String[] cookieTokens = decodeCookie(rememberMeCookie);
user = processAutoLoginCookie(cookieTokens, request, response);
UserDetails user = processAutoLoginCookie(cookieTokens, request, response);
this.userDetailsChecker.check(user);
this.logger.debug("Remember-me cookie accepted");
return createSuccessfulAuthentication(request, user);
}
catch (CookieTheftException cte) {
catch (CookieTheftException ex) {
cancelCookie(request, response);
throw cte;
throw ex;
}
catch (UsernameNotFoundException noUser) {
this.logger.debug("Remember-me login was valid but corresponding user not found.", noUser);
catch (UsernameNotFoundException ex) {
this.logger.debug("Remember-me login was valid but corresponding user not found.", ex);
}
catch (InvalidCookieException invalidCookie) {
this.logger.debug("Invalid remember-me cookie: " + invalidCookie.getMessage());
catch (InvalidCookieException ex) {
this.logger.debug("Invalid remember-me cookie: " + ex.getMessage());
}
catch (AccountStatusException statusInvalid) {
this.logger.debug("Invalid UserDetails: " + statusInvalid.getMessage());
catch (AccountStatusException ex) {
this.logger.debug("Invalid UserDetails: " + ex.getMessage());
}
catch (RememberMeAuthenticationException ex) {
this.logger.debug(ex.getMessage());
}
cancelCookie(request, response);
return null;
}
@ -172,17 +164,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
*/
protected String extractRememberMeCookie(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if ((cookies == null) || (cookies.length == 0)) {
return null;
}
for (Cookie cookie : cookies) {
if (this.cookieName.equals(cookie.getName())) {
return cookie.getValue();
}
}
return null;
}
@ -216,18 +205,14 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
for (int j = 0; j < cookieValue.length() % 4; j++) {
cookieValue = cookieValue + "=";
}
try {
Base64.getDecoder().decode(cookieValue.getBytes());
}
catch (IllegalArgumentException ex) {
throw new InvalidCookieException("Cookie token was not Base64 encoded; value was '" + cookieValue + "'");
}
String cookieAsPlainText = new String(Base64.getDecoder().decode(cookieValue.getBytes()));
String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER);
for (int i = 0; i < tokens.length; i++) {
try {
tokens[i] = URLDecoder.decode(tokens[i], StandardCharsets.UTF_8.toString());
@ -236,7 +221,6 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
this.logger.error(ex.getMessage(), ex);
}
}
return tokens;
}
@ -254,20 +238,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
catch (UnsupportedEncodingException ex) {
this.logger.error(ex.getMessage(), ex);
}
if (i < cookieTokens.length - 1) {
sb.append(DELIMITER);
}
}
String value = sb.toString();
sb = new StringBuilder(new String(Base64.getEncoder().encode(value.getBytes())));
while (sb.charAt(sb.length() - 1) == '=') {
sb.deleteCharAt(sb.length() - 1);
}
return sb.toString();
}
@ -293,12 +272,10 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
@Override
public final void loginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) {
if (!rememberMeRequested(request, this.parameter)) {
this.logger.debug("Remember-me login not requested.");
return;
}
onLoginSuccess(request, response, successfulAuthentication);
}
@ -324,20 +301,15 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (this.alwaysRemember) {
return true;
}
String paramValue = request.getParameter(parameter);
if (paramValue != null) {
if (paramValue.equalsIgnoreCase("true") || paramValue.equalsIgnoreCase("on")
|| paramValue.equalsIgnoreCase("yes") || paramValue.equals("1")) {
return true;
}
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Did not send remember-me cookie (principal did not set parameter '" + parameter + "')");
}
this.logger.debug(
LogMessage.format("Did not send remember-me cookie (principal did not set parameter '%s')", parameter));
return false;
}
@ -370,12 +342,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (this.cookieDomain != null) {
cookie.setDomain(this.cookieDomain);
}
if (this.useSecureCookie == null) {
cookie.setSecure(request.isSecure());
}
else {
cookie.setSecure(this.useSecureCookie);
}
cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure());
response.addCookie(cookie);
}
@ -402,16 +369,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
if (maxAge < 1) {
cookie.setVersion(1);
}
if (this.useSecureCookie == null) {
cookie.setSecure(request.isSecure());
}
else {
cookie.setSecure(this.useSecureCookie);
}
cookie.setSecure((this.useSecureCookie != null) ? this.useSecureCookie : request.isSecure());
cookie.setHttpOnly(true);
response.addCookie(cookie);
}
@ -426,9 +385,8 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
*/
@Override
public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown"));
}
this.logger.debug(LogMessage
.of(() -> "Logout of user " + ((authentication != null) ? authentication.getName() : "Unknown")));
cancelCookie(request, response);
}

View File

@ -36,21 +36,17 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository {
@Override
public synchronized void createNewToken(PersistentRememberMeToken token) {
PersistentRememberMeToken current = this.seriesTokens.get(token.getSeries());
if (current != null) {
throw new DataIntegrityViolationException("Series Id '" + token.getSeries() + "' already exists!");
}
this.seriesTokens.put(token.getSeries(), token);
}
@Override
public synchronized void updateToken(String series, String tokenValue, Date lastUsed) {
PersistentRememberMeToken token = getTokenForSeries(series);
PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), series, tokenValue,
new Date());
// Store it, overwriting the existing one.
this.seriesTokens.put(series, newToken);
}
@ -63,12 +59,9 @@ public class InMemoryTokenRepositoryImpl implements PersistentTokenRepository {
@Override
public synchronized void removeUserTokens(String username) {
Iterator<String> series = this.seriesTokens.keySet().iterator();
while (series.hasNext()) {
String seriesId = series.next();
PersistentRememberMeToken token = this.seriesTokens.get(seriesId);
if (username.equals(token.getUsername())) {
series.remove();
}

View File

@ -16,8 +16,11 @@
package org.springframework.security.web.authentication.rememberme;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Date;
import org.springframework.core.log.LogMessage;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
@ -87,27 +90,26 @@ public class JdbcTokenRepositoryImpl extends JdbcDaoSupport implements Persisten
@Override
public PersistentRememberMeToken getTokenForSeries(String seriesId) {
try {
return getJdbcTemplate().queryForObject(this.tokensBySeriesSql,
(rs, rowNum) -> new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3),
rs.getTimestamp(4)),
seriesId);
return getJdbcTemplate().queryForObject(this.tokensBySeriesSql, this::createRememberMeToken, seriesId);
}
catch (EmptyResultDataAccessException zeroResults) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Querying token for series '" + seriesId + "' returned no results.", zeroResults);
}
catch (EmptyResultDataAccessException ex) {
this.logger.debug(LogMessage.format("Querying token for series '%s' returned no results.", seriesId), ex);
}
catch (IncorrectResultSizeDataAccessException moreThanOne) {
this.logger.error("Querying token for series '" + seriesId + "' returned more than one value. Series"
+ " should be unique");
catch (IncorrectResultSizeDataAccessException ex) {
this.logger.error(LogMessage.format(
"Querying token for series '%s' returned more than one value. Series" + " should be unique",
seriesId));
}
catch (DataAccessException ex) {
this.logger.error("Failed to load token for series " + seriesId, ex);
}
return null;
}
private PersistentRememberMeToken createRememberMeToken(ResultSet rs, int rowNum) throws SQLException {
return new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3), rs.getTimestamp(4));
}
@Override
public void removeUserTokens(String username) {
getJdbcTemplate().update(this.removeUserTokensSql, username);

View File

@ -24,6 +24,7 @@ import java.util.Date;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
@ -93,47 +94,35 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
@Override
protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request,
HttpServletResponse response) {
if (cookieTokens.length != 2) {
throw new InvalidCookieException("Cookie token did not contain " + 2 + " tokens, but contained '"
+ Arrays.asList(cookieTokens) + "'");
}
final String presentedSeries = cookieTokens[0];
final String presentedToken = cookieTokens[1];
String presentedSeries = cookieTokens[0];
String presentedToken = cookieTokens[1];
PersistentRememberMeToken token = this.tokenRepository.getTokenForSeries(presentedSeries);
if (token == null) {
// No series match, so we can't authenticate using this cookie
throw new RememberMeAuthenticationException("No persistent token found for series id: " + presentedSeries);
}
// We have a match for this user/series combination
if (!presentedToken.equals(token.getTokenValue())) {
// Token doesn't match series value. Delete all logins for this user and throw
// an exception to warn them.
this.tokenRepository.removeUserTokens(token.getUsername());
throw new CookieTheftException(this.messages.getMessage(
"PersistentTokenBasedRememberMeServices.cookieStolen",
"Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack."));
}
if (token.getDate().getTime() + getTokenValiditySeconds() * 1000L < System.currentTimeMillis()) {
throw new RememberMeAuthenticationException("Remember-me login has expired");
}
// Token also matches, so login is valid. Update the token value, keeping the
// *same* series number.
if (this.logger.isDebugEnabled()) {
this.logger.debug("Refreshing persistent login token for user '" + token.getUsername() + "', series '"
+ token.getSeries() + "'");
}
this.logger.debug(LogMessage.format("Refreshing persistent login token for user '%s', series '%s'",
token.getUsername(), token.getSeries()));
PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(), token.getSeries(),
generateTokenData(), new Date());
try {
this.tokenRepository.updateToken(newToken.getSeries(), newToken.getTokenValue(), newToken.getDate());
addCookie(newToken, request, response);
@ -142,7 +131,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
this.logger.error("Failed to update token: ", ex);
throw new RememberMeAuthenticationException("Autologin failed due to data access problem");
}
return getUserDetailsService().loadUserByUsername(token.getUsername());
}
@ -155,9 +143,7 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
protected void onLoginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) {
String username = successfulAuthentication.getName();
this.logger.debug("Creating new persistent login for user " + username);
this.logger.debug(LogMessage.format("Creating new persistent login for user %s", username));
PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, generateSeriesData(),
generateTokenData(), new Date());
try {
@ -172,7 +158,6 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
@Override
public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
super.logout(request, response, authentication);
if (authentication != null) {
this.tokenRepository.removeUserTokens(authentication.getName());
}

View File

@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent;
import org.springframework.security.core.Authentication;
@ -86,66 +87,50 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
if (SecurityContextHolder.getContext().getAuthentication() == null) {
Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response);
if (rememberMeAuth != null) {
// Attempt authenticaton via AuthenticationManager
try {
rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth);
// Store to SecurityContextHolder
SecurityContextHolder.getContext().setAuthentication(rememberMeAuth);
onSuccessfulAuthentication(request, response, rememberMeAuth);
if (this.logger.isDebugEnabled()) {
this.logger.debug("SecurityContextHolder populated with remember-me token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
}
// Fire event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
}
if (this.successHandler != null) {
this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth);
return;
}
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (SecurityContextHolder.getContext().getAuthentication() != null) {
this.logger.debug(LogMessage
.of(() -> "SecurityContextHolder not populated with remember-me token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
chain.doFilter(request, response);
return;
}
Authentication rememberMeAuth = this.rememberMeServices.autoLogin(request, response);
if (rememberMeAuth != null) {
// Attempt authenticaton via AuthenticationManager
try {
rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth);
// Store to SecurityContextHolder
SecurityContextHolder.getContext().setAuthentication(rememberMeAuth);
onSuccessfulAuthentication(request, response, rememberMeAuth);
this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'"));
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
}
catch (AuthenticationException authenticationException) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("SecurityContextHolder not populated with remember-me token, as "
+ "AuthenticationManager rejected Authentication returned by RememberMeServices: '"
+ rememberMeAuth + "'; invalidating remember-me token", authenticationException);
}
this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, authenticationException);
if (this.successHandler != null) {
this.successHandler.onAuthenticationSuccess(request, response, rememberMeAuth);
return;
}
}
chain.doFilter(request, response);
}
else {
if (this.logger.isDebugEnabled()) {
this.logger
.debug("SecurityContextHolder not populated with remember-me token, as it already contained: '"
+ SecurityContextHolder.getContext().getAuthentication() + "'");
catch (AuthenticationException ex) {
this.logger.debug(LogMessage
.format("SecurityContextHolder not populated with remember-me token, as AuthenticationManager "
+ "rejected Authentication returned by RememberMeServices: '%s'; "
+ "invalidating remember-me token", rememberMeAuth),
ex);
this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, ex);
}
chain.doFilter(request, response);
}
chain.doFilter(request, response);
}
/**

View File

@ -90,52 +90,43 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
@Override
protected UserDetails processAutoLoginCookie(String[] cookieTokens, HttpServletRequest request,
HttpServletResponse response) {
if (cookieTokens.length != 3) {
throw new InvalidCookieException(
"Cookie token did not contain 3" + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'");
}
long tokenExpiryTime = getTokenExpiryTime(cookieTokens);
if (isTokenExpired(tokenExpiryTime)) {
throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime)
+ "'; current time is '" + new Date() + "')");
}
// Check the user exists. Defer lookup until after expiry time checked, to
// possibly avoid expensive database call.
UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]);
Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService()
+ " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation");
// Check signature of token matches remaining details. Must do this after user
// lookup, as we need the DAO-derived password. If efficiency was a major issue,
// just add in a UserCache implementation, but recall that this method is usually
// only called once per HttpSession - if the token is valid, it will cause
// SecurityContextHolder population, whilst if invalid, will cause the cookie to
// be cancelled.
String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(),
userDetails.getPassword());
if (!equals(expectedTokenSignature, cookieTokens[2])) {
throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2]
+ "' but expected '" + expectedTokenSignature + "'");
}
return userDetails;
}
long tokenExpiryTime;
private long getTokenExpiryTime(String[] cookieTokens) {
try {
tokenExpiryTime = new Long(cookieTokens[1]);
return new Long(cookieTokens[1]);
}
catch (NumberFormatException nfe) {
throw new InvalidCookieException(
"Cookie token[1] did not contain a valid number (contained '" + cookieTokens[1] + "')");
}
if (isTokenExpired(tokenExpiryTime)) {
throw new InvalidCookieException("Cookie token[1] has expired (expired on '" + new Date(tokenExpiryTime)
+ "'; current time is '" + new Date() + "')");
}
// Check the user exists.
// Defer lookup until after expiry time checked, to possibly avoid expensive
// database call.
UserDetails userDetails = getUserDetailsService().loadUserByUsername(cookieTokens[0]);
Assert.notNull(userDetails, () -> "UserDetailsService " + getUserDetailsService()
+ " returned null for username " + cookieTokens[0] + ". " + "This is an interface contract violation");
// Check signature of token matches remaining details.
// Must do this after user lookup, as we need the DAO-derived password.
// If efficiency was a major issue, just add in a UserCache implementation,
// but recall that this method is usually only called once per HttpSession - if
// the token is valid,
// it will cause SecurityContextHolder population, whilst if invalid, will cause
// the cookie to be cancelled.
String expectedTokenSignature = makeTokenSignature(tokenExpiryTime, userDetails.getUsername(),
userDetails.getPassword());
if (!equals(expectedTokenSignature, cookieTokens[2])) {
throw new InvalidCookieException("Cookie token[2] contained signature '" + cookieTokens[2]
+ "' but expected '" + expectedTokenSignature + "'");
}
return userDetails;
}
/**
@ -144,15 +135,13 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
*/
protected String makeTokenSignature(long tokenExpiryTime, String username, String password) {
String data = username + ":" + tokenExpiryTime + ":" + password + ":" + getKey();
MessageDigest digest;
try {
digest = MessageDigest.getInstance("MD5");
MessageDigest digest = MessageDigest.getInstance("MD5");
return new String(Hex.encode(digest.digest(data.getBytes())));
}
catch (NoSuchAlgorithmException ex) {
throw new IllegalStateException("No MD5 algorithm available!");
}
return new String(Hex.encode(digest.digest(data.getBytes())));
}
protected boolean isTokenExpired(long tokenExpiryTime) {
@ -162,10 +151,8 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
@Override
public void onLoginSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication successfulAuthentication) {
String username = retrieveUserName(successfulAuthentication);
String password = retrievePassword(successfulAuthentication);
// If unable to find a username and password, just abort as
// TokenBasedRememberMeServices is
// unable to construct a valid token in this case.
@ -173,27 +160,21 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
this.logger.debug("Unable to retrieve username");
return;
}
if (!StringUtils.hasLength(password)) {
UserDetails user = getUserDetailsService().loadUserByUsername(username);
password = user.getPassword();
if (!StringUtils.hasLength(password)) {
this.logger.debug("Unable to obtain password for user: " + username);
return;
}
}
int tokenLifetime = calculateLoginLifetime(request, successfulAuthentication);
long expiryTime = System.currentTimeMillis();
// SEC-949
expiryTime += 1000L * ((tokenLifetime < 0) ? TWO_WEEKS_S : tokenLifetime);
String signatureValue = makeTokenSignature(expiryTime, username, password);
setCookie(new String[] { username, Long.toString(expiryTime), signatureValue }, tokenLifetime, request,
response);
if (this.logger.isDebugEnabled()) {
this.logger.debug(
"Added remember-me cookie for user '" + username + "', expiry: '" + new Date(expiryTime) + "'");
@ -223,21 +204,17 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
if (isInstanceOfUserDetails(authentication)) {
return ((UserDetails) authentication.getPrincipal()).getUsername();
}
else {
return authentication.getPrincipal().toString();
}
return authentication.getPrincipal().toString();
}
protected String retrievePassword(Authentication authentication) {
if (isInstanceOfUserDetails(authentication)) {
return ((UserDetails) authentication.getPrincipal()).getPassword();
}
else {
if (authentication.getCredentials() == null) {
return null;
}
if (authentication.getCredentials() != null) {
return authentication.getCredentials().toString();
}
return null;
}
private boolean isInstanceOfUserDetails(Authentication authentication) {
@ -250,15 +227,11 @@ public class TokenBasedRememberMeServices extends AbstractRememberMeServices {
private static boolean equals(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected);
byte[] actualBytes = bytesUtf8(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
}
private static byte[] bytesUtf8(String s) {
if (s == null) {
return null;
}
return Utf8.encode(s);
return (s != null) ? Utf8.encode(s) : null;
}
}

View File

@ -73,35 +73,26 @@ public abstract class AbstractSessionFixationProtectionStrategy
public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) {
boolean hadSessionAlready = request.getSession(false) != null;
if (!hadSessionAlready && !this.alwaysCreateSession) {
// Session fixation isn't a problem if there's no session
return;
}
// Create new session if necessary
HttpSession session = request.getSession();
if (hadSessionAlready && request.isRequestedSessionIdValid()) {
String originalSessionId;
String newSessionId;
Object mutex = WebUtils.getSessionMutex(session);
synchronized (mutex) {
// We need to migrate to a new session
originalSessionId = session.getId();
session = applySessionFixation(request);
newSessionId = session.getId();
}
if (originalSessionId.equals(newSessionId)) {
this.logger.warn(
"Your servlet container did not change the session ID when a new session was created. You will"
+ " not be adequately protected against session-fixation attacks");
this.logger.warn("Your servlet container did not change the session ID when a new session "
+ "was created. You will not be adequately protected against session-fixation attacks");
}
onSessionChange(originalSessionId, session, authentication);
}
}

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -63,10 +64,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica
public CompositeSessionAuthenticationStrategy(List<SessionAuthenticationStrategy> delegateStrategies) {
Assert.notEmpty(delegateStrategies, "delegateStrategies cannot be null or empty");
for (SessionAuthenticationStrategy strategy : delegateStrategies) {
if (strategy == null) {
throw new IllegalArgumentException(
"delegateStrategies cannot contain null entires. Got " + delegateStrategies);
}
Assert.notNull(strategy, () -> "delegateStrategies cannot contain null entires. Got " + delegateStrategies);
}
this.delegateStrategies = delegateStrategies;
}
@ -75,9 +73,7 @@ public class CompositeSessionAuthenticationStrategy implements SessionAuthentica
public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) throws SessionAuthenticationException {
for (SessionAuthenticationStrategy delegate : this.delegateStrategies) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Delegating to " + delegate);
}
this.logger.debug(LogMessage.format("Delegating to %s", delegate));
delegate.onAuthentication(authentication, request, response);
}
}

View File

@ -94,26 +94,19 @@ public class ConcurrentSessionControlAuthenticationStrategy
@Override
public void onAuthentication(Authentication authentication, HttpServletRequest request,
HttpServletResponse response) {
final List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(),
false);
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(authentication.getPrincipal(), false);
int sessionCount = sessions.size();
int allowedSessions = getMaximumSessionsForThisUser(authentication);
if (sessionCount < allowedSessions) {
// They haven't got too many login sessions running at present
return;
}
if (allowedSessions == -1) {
// We permit unlimited logins
return;
}
if (sessionCount == allowedSessions) {
HttpSession session = request.getSession(false);
if (session != null) {
// Only permit it though if this request is associated with one of the
// already registered sessions
@ -126,7 +119,6 @@ public class ConcurrentSessionControlAuthenticationStrategy
// If the session is null, a new one will be created by the parent class,
// exceeding the allowed number
}
allowableSessionsExceeded(sessions, allowedSessions, this.sessionRegistry);
}
@ -157,7 +149,6 @@ public class ConcurrentSessionControlAuthenticationStrategy
this.messages.getMessage("ConcurrentSessionControlAuthenticationStrategy.exceededAllowed",
new Object[] { allowableSessions }, "Maximum sessions of {0} for this principal exceeded"));
}
// Determine least recently used sessions, and mark them for invalidation
sessions.sort(Comparator.comparing(SessionInformation::getLastRequest));
int maximumSessionsExceededBy = sessions.size() - allowableSessions + 1;

View File

@ -23,6 +23,8 @@ import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import org.springframework.core.log.LogMessage;
/**
* Uses {@code HttpServletRequest.invalidate()} to protect against session fixation
* attacks.
@ -82,21 +84,13 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr
final HttpSession applySessionFixation(HttpServletRequest request) {
HttpSession session = request.getSession();
String originalSessionId = session.getId();
if (this.logger.isDebugEnabled()) {
this.logger.debug("Invalidating session with Id '" + originalSessionId + "' "
+ (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes.");
}
this.logger.debug(LogMessage.of(() -> "Invalidating session with Id '" + originalSessionId + "' "
+ (this.migrateSessionAttributes ? "and" : "without") + " migrating attributes."));
Map<String, Object> attributesToMigrate = extractAttributes(session);
int maxInactiveIntervalToMigrate = session.getMaxInactiveInterval();
session.invalidate();
session = request.getSession(true); // we now have a new session
if (this.logger.isDebugEnabled()) {
this.logger.debug("Started new session: " + session.getId());
}
this.logger.debug(LogMessage.format("Started new session: %s", session.getId()));
transferAttributes(attributesToMigrate, session);
if (this.migrateSessionAttributes) {
session.setMaxInactiveInterval(maxInactiveIntervalToMigrate);
@ -111,27 +105,22 @@ public class SessionFixationProtectionStrategy extends AbstractSessionFixationPr
*/
void transferAttributes(Map<String, Object> attributes, HttpSession newSession) {
if (attributes != null) {
for (Map.Entry<String, Object> entry : attributes.entrySet()) {
newSession.setAttribute(entry.getKey(), entry.getValue());
}
attributes.forEach(newSession::setAttribute);
}
}
@SuppressWarnings("unchecked")
private HashMap<String, Object> createMigratedAttributeMap(HttpSession session) {
HashMap<String, Object> attributesToMigrate = new HashMap<>();
Enumeration enumer = session.getAttributeNames();
while (enumer.hasMoreElements()) {
String key = (String) enumer.nextElement();
Enumeration<String> enumeration = session.getAttributeNames();
while (enumeration.hasMoreElements()) {
String key = enumeration.nextElement();
if (!this.migrateSessionAttributes && !key.startsWith("SPRING_SECURITY_")) {
// Only retain Spring Security attributes
continue;
}
attributesToMigrate.put(key, session.getAttribute(key));
}
return attributesToMigrate;
}

View File

@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AccountExpiredException;
import org.springframework.security.authentication.AccountStatusUserDetailsChecker;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
@ -149,7 +150,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
Assert.isNull(this.successHandler, "You cannot set both successHandler and targetUrl");
this.successHandler = new SimpleUrlAuthenticationSuccessHandler(this.targetUrl);
}
if (this.failureHandler == null) {
this.failureHandler = (this.switchFailureUrl != null)
? new SimpleUrlAuthenticationFailureHandler(this.switchFailureUrl)
@ -161,20 +161,20 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
// check for switch or exit request
if (requiresSwitchUser(request)) {
// if set, attempt switch and store original
try {
Authentication targetUser = attemptSwitchUser(request);
// update the current context to the new target user
SecurityContextHolder.getContext().setAuthentication(targetUser);
// redirect to target url
this.successHandler.onAuthenticationSuccess(request, response, targetUser);
}
@ -182,22 +182,17 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
this.logger.debug("Switch User failed", ex);
this.failureHandler.onAuthenticationFailure(request, response, ex);
}
return;
}
else if (requiresExitUser(request)) {
if (requiresExitUser(request)) {
// get the original authentication object (if exists)
Authentication originalUser = attemptExitUser(request);
// update the current context back to the original user
SecurityContextHolder.getContext().setAuthentication(originalUser);
// redirect to target url
this.successHandler.onAuthenticationSuccess(request, response, originalUser);
return;
}
chain.doFilter(request, response);
}
@ -214,33 +209,19 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/
protected Authentication attemptSwitchUser(HttpServletRequest request) throws AuthenticationException {
UsernamePasswordAuthenticationToken targetUserRequest;
String username = request.getParameter(this.usernameParameter);
if (username == null) {
username = "";
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Attempt to switch to user [" + username + "]");
}
username = (username != null) ? username : "";
this.logger.debug(LogMessage.format("Attempt to switch to user [%s]", username));
UserDetails targetUser = this.userDetailsService.loadUserByUsername(username);
this.userDetailsChecker.check(targetUser);
// OK, create the switch user token
targetUserRequest = createSwitchUserToken(request, targetUser);
if (this.logger.isDebugEnabled()) {
this.logger.debug("Switch User Token [" + targetUserRequest + "]");
}
this.logger.debug(LogMessage.format("Switch User Token [%s]", targetUserRequest));
// publish event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(
SecurityContextHolder.getContext().getAuthentication(), targetUser));
}
return targetUserRequest;
}
@ -256,35 +237,28 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
throws AuthenticationCredentialsNotFoundException {
// need to check to see if the current user has a SwitchUserGrantedAuthority
Authentication current = SecurityContextHolder.getContext().getAuthentication();
if (null == current) {
if (current == null) {
throw new AuthenticationCredentialsNotFoundException(this.messages
.getMessage("SwitchUserFilter.noCurrentUser", "No current user associated with this request"));
}
// check to see if the current user did actual switch to another user
// if so, get the original source user so we can switch back
Authentication original = getSourceAuthentication(current);
if (original == null) {
this.logger.debug("Could not find original user Authentication object!");
throw new AuthenticationCredentialsNotFoundException(this.messages.getMessage(
"SwitchUserFilter.noOriginalAuthentication", "Could not find original Authentication object"));
}
// get the source user details
UserDetails originalUser = null;
Object obj = original.getPrincipal();
if ((obj != null) && obj instanceof UserDetails) {
originalUser = (UserDetails) obj;
}
// publish event
if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new AuthenticationSwitchUserEvent(current, originalUser));
}
return original;
}
@ -299,45 +273,38 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/
private UsernamePasswordAuthenticationToken createSwitchUserToken(HttpServletRequest request,
UserDetails targetUser) {
UsernamePasswordAuthenticationToken targetUserRequest;
// grant an additional authority that contains the original Authentication object
// which will be used to 'exit' from the current switched user.
Authentication currentAuth;
try {
// SEC-1763. Check first if we are already switched.
currentAuth = attemptExitUser(request);
}
catch (AuthenticationCredentialsNotFoundException ex) {
currentAuth = SecurityContextHolder.getContext().getAuthentication();
}
GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole, currentAuth);
Authentication currentAuthentication = getCurrentAuthentication(request);
GrantedAuthority switchAuthority = new SwitchUserGrantedAuthority(this.switchAuthorityRole,
currentAuthentication);
// get the original authorities
Collection<? extends GrantedAuthority> orig = targetUser.getAuthorities();
// Allow subclasses to change the authorities to be granted
if (this.switchUserAuthorityChanger != null) {
orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuth, orig);
orig = this.switchUserAuthorityChanger.modifyGrantedAuthorities(targetUser, currentAuthentication, orig);
}
// add the new switch user authority
List<GrantedAuthority> newAuths = new ArrayList<>(orig);
newAuths.add(switchAuthority);
// create the new authentication token
targetUserRequest = new UsernamePasswordAuthenticationToken(targetUser, targetUser.getPassword(), newAuths);
// set details
targetUserRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
return targetUserRequest;
}
private Authentication getCurrentAuthentication(HttpServletRequest request) {
try {
// SEC-1763. Check first if we are already switched.
return attemptExitUser(request);
}
catch (AuthenticationCredentialsNotFoundException ex) {
return SecurityContextHolder.getContext().getAuthentication();
}
}
/**
* Find the original <code>Authentication</code> object from the current user's
* granted authorities. A successfully switched user should have a
@ -349,10 +316,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
*/
private Authentication getSourceAuthentication(Authentication current) {
Authentication original = null;
// iterate over granted authorities and find the 'switch user' authority
Collection<? extends GrantedAuthority> authorities = current.getAuthorities();
for (GrantedAuthority auth : authorities) {
// check for switch user type of authority
if (auth instanceof SwitchUserGrantedAuthority) {
@ -360,7 +325,6 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
this.logger.debug("Found original switch user granted authority [" + original + "]");
}
}
return original;
}

View File

@ -112,24 +112,28 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
this.logoutSuccessUrl = DEFAULT_LOGIN_PAGE_URL + "?logout";
this.failureUrl = DEFAULT_LOGIN_PAGE_URL + "?" + ERROR_PARAMETER_NAME;
if (authFilter != null) {
this.formLoginEnabled = true;
this.usernameParameter = authFilter.getUsernameParameter();
this.passwordParameter = authFilter.getPasswordParameter();
if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices())
.getParameter();
}
initAuthFilter(authFilter);
}
if (openIDFilter != null) {
this.openIdEnabled = true;
this.openIDusernameParameter = "openid_identifier";
initOpenIdFilter(openIDFilter);
}
}
if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices())
.getParameter();
}
private void initAuthFilter(UsernamePasswordAuthenticationFilter authFilter) {
this.formLoginEnabled = true;
this.usernameParameter = authFilter.getUsernameParameter();
this.passwordParameter = authFilter.getPasswordParameter();
if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices()).getParameter();
}
}
private void initOpenIdFilter(AbstractAuthenticationProcessingFilter openIDFilter) {
this.openIdEnabled = true;
this.openIDusernameParameter = "openid_identifier";
if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices())
.getParameter();
}
}
@ -214,11 +218,13 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
boolean loginError = isErrorPage(request);
boolean logoutSuccess = isLogoutSuccess(request);
if (isLoginUrlRequest(request) || loginError || logoutSuccess) {
@ -226,66 +232,69 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
response.setContentType("text/html;charset=UTF-8");
response.setContentLength(loginPageHtml.getBytes(StandardCharsets.UTF_8).length);
response.getWriter().write(loginPageHtml);
return;
}
chain.doFilter(request, response);
}
private String generateLoginPageHtml(HttpServletRequest request, boolean loginError, boolean logoutSuccess) {
String errorMsg = "Invalid credentials";
if (loginError) {
HttpSession session = request.getSession(false);
if (session != null) {
AuthenticationException ex = (AuthenticationException) session
.getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
errorMsg = (ex != null) ? ex.getMessage() : "Invalid credentials";
}
}
StringBuilder sb = new StringBuilder();
sb.append("<!DOCTYPE html>\n" + "<html lang=\"en\">\n" + " <head>\n" + " <meta charset=\"utf-8\">\n"
+ " <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n"
+ " <meta name=\"description\" content=\"\">\n" + " <meta name=\"author\" content=\"\">\n"
+ " <title>Please sign in</title>\n"
+ " <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n"
+ " <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n"
+ " </head>\n" + " <body>\n" + " <div class=\"container\">\n");
String contextPath = request.getContextPath();
StringBuilder sb = new StringBuilder();
sb.append("<!DOCTYPE html>\n");
sb.append("<html lang=\"en\">\n");
sb.append(" <head>\n");
sb.append(" <meta charset=\"utf-8\">\n");
sb.append(" <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n");
sb.append(" <meta name=\"description\" content=\"\">\n");
sb.append(" <meta name=\"author\" content=\"\">\n");
sb.append(" <title>Please sign in</title>\n");
sb.append(" <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" "
+ "rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n");
sb.append(" <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" "
+ "rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n");
sb.append(" </head>\n");
sb.append(" <body>\n");
sb.append(" <div class=\"container\">\n");
if (this.formLoginEnabled) {
sb.append(" <form class=\"form-signin\" method=\"post\" action=\"" + contextPath
+ this.authenticationUrl + "\">\n"
+ " <h2 class=\"form-signin-heading\">Please sign in</h2>\n"
+ createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n"
+ " <label for=\"username\" class=\"sr-only\">Username</label>\n"
+ " <input type=\"text\" id=\"username\" name=\"" + this.usernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n" + " </p>\n"
+ " <p>\n" + " <label for=\"password\" class=\"sr-only\">Password</label>\n"
+ " <input type=\"password\" id=\"password\" name=\"" + this.passwordParameter
+ "\" class=\"form-control\" placeholder=\"Password\" required>\n" + " </p>\n"
+ createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request)
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
+ " </form>\n");
+ this.authenticationUrl + "\">\n");
sb.append(" <h2 class=\"form-signin-heading\">Please sign in</h2>\n");
sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n");
sb.append(" <label for=\"username\" class=\"sr-only\">Username</label>\n");
sb.append(" <input type=\"text\" id=\"username\" name=\"" + this.usernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n");
sb.append(" </p>\n");
sb.append(" <p>\n");
sb.append(" <label for=\"password\" class=\"sr-only\">Password</label>\n");
sb.append(" <input type=\"password\" id=\"password\" name=\"" + this.passwordParameter
+ "\" class=\"form-control\" placeholder=\"Password\" required>\n");
sb.append(" </p>\n");
sb.append(createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request));
sb.append(" <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n");
sb.append(" </form>\n");
}
if (this.openIdEnabled) {
sb.append(" <form name=\"oidf\" class=\"form-signin\" method=\"post\" action=\"" + contextPath
+ this.openIDauthenticationUrl + "\">\n"
+ " <h2 class=\"form-signin-heading\">Login with OpenID Identity</h2>\n"
+ createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n"
+ " <label for=\"username\" class=\"sr-only\">Identity</label>\n"
+ " <input type=\"text\" id=\"username\" name=\"" + this.openIDusernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n" + " </p>\n"
+ createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request)
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
+ " </form>\n");
+ this.openIDauthenticationUrl + "\">\n");
sb.append(" <h2 class=\"form-signin-heading\">Login with OpenID Identity</h2>\n");
sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + " <p>\n");
sb.append(" <label for=\"username\" class=\"sr-only\">Identity</label>\n");
sb.append(" <input type=\"text\" id=\"username\" name=\"" + this.openIDusernameParameter
+ "\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n");
sb.append(" </p>\n");
sb.append(createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request));
sb.append(" <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n");
sb.append(" </form>\n");
}
if (this.oauth2LoginEnabled) {
sb.append("<h2 class=\"form-signin-heading\">Login with OAuth 2.0</h2>");
sb.append(createError(loginError, errorMsg));
@ -303,7 +312,6 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
}
sb.append("</table>\n");
}
if (this.saml2LoginEnabled) {
sb.append("<h2 class=\"form-signin-heading\">Login with SAML 2.0</h2>");
sb.append(createError(loginError, errorMsg));
@ -323,15 +331,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
}
sb.append("</div>\n");
sb.append("</body></html>");
return sb.toString();
}
private String renderHiddenInputs(HttpServletRequest request) {
StringBuilder sb = new StringBuilder();
for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) {
sb.append("<input name=\"").append(input.getKey()).append("\" type=\"hidden\" value=\"")
.append(input.getValue()).append("\" />\n");
sb.append("<input name=\"");
sb.append(input.getKey());
sb.append("\" type=\"hidden\" value=\"");
sb.append(input.getValue());
sb.append("\" />\n");
}
return sb.toString();
}
@ -356,13 +366,17 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
}
private static String createError(boolean isError, String message) {
return isError ? "<div class=\"alert alert-danger\" role=\"alert\">" + HtmlUtils.htmlEscape(message) + "</div>"
: "";
if (!isError) {
return "";
}
return "<div class=\"alert alert-danger\" role=\"alert\">" + HtmlUtils.htmlEscape(message) + "</div>";
}
private static String createLogoutSuccess(boolean isLogoutSuccess) {
return isLogoutSuccess ? "<div class=\"alert alert-success\" role=\"alert\">You have been signed out</div>"
: "";
if (!isLogoutSuccess) {
return "";
}
return "<div class=\"alert alert-success\" role=\"alert\">You have been signed out</div>";
}
private boolean matches(HttpServletRequest request, String url) {
@ -371,20 +385,16 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
}
String uri = request.getRequestURI();
int pathParamIndex = uri.indexOf(';');
if (pathParamIndex > 0) {
// strip everything after the first semi-colon
uri = uri.substring(0, pathParamIndex);
}
if (request.getQueryString() != null) {
uri += "?" + request.getQueryString();
}
if ("".equals(request.getContextPath())) {
return uri.equals(url);
}
return uri.equals(request.getContextPath() + url);
}

View File

@ -55,21 +55,34 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter {
}
private void renderLogout(HttpServletRequest request, HttpServletResponse response) throws IOException {
String page = "<!DOCTYPE html>\n" + "<html lang=\"en\">\n" + " <head>\n" + " <meta charset=\"utf-8\">\n"
+ " <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n"
+ " <meta name=\"description\" content=\"\">\n" + " <meta name=\"author\" content=\"\">\n"
+ " <title>Confirm Log Out?</title>\n"
+ " <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n"
+ " <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n"
+ " </head>\n" + " <body>\n" + " <div class=\"container\">\n"
+ " <form class=\"form-signin\" method=\"post\" action=\"" + request.getContextPath()
+ "/logout\">\n" + " <h2 class=\"form-signin-heading\">Are you sure you want to log out?</h2>\n"
+ renderHiddenInputs(request)
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Log Out</button>\n"
+ " </form>\n" + " </div>\n" + " </body>\n" + "</html>";
StringBuilder sb = new StringBuilder();
sb.append("<!DOCTYPE html>\n");
sb.append("<html lang=\"en\">\n");
sb.append(" <head>\n");
sb.append(" <meta charset=\"utf-8\">\n");
sb.append(" <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n");
sb.append(" <meta name=\"description\" content=\"\">\n");
sb.append(" <meta name=\"author\" content=\"\">\n");
sb.append(" <title>Confirm Log Out?</title>\n");
sb.append(" <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" "
+ "rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" "
+ "crossorigin=\"anonymous\">\n");
sb.append(" <link href=\"https://getbootstrap.com/docs/4.0/examples/signin/signin.css\" "
+ "rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n");
sb.append(" </head>\n");
sb.append(" <body>\n");
sb.append(" <div class=\"container\">\n");
sb.append(" <form class=\"form-signin\" method=\"post\" action=\"" + request.getContextPath()
+ "/logout\">\n");
sb.append(" <h2 class=\"form-signin-heading\">Are you sure you want to log out?</h2>\n");
sb.append(renderHiddenInputs(request)
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Log Out</button>\n");
sb.append(" </form>\n");
sb.append(" </div>\n");
sb.append(" </body>\n");
sb.append("</html>");
response.setContentType("text/html;charset=UTF-8");
response.getWriter().write(page);
response.getWriter().write(sb.toString());
}
/**
@ -86,8 +99,11 @@ public class DefaultLogoutPageGeneratingFilter extends OncePerRequestFilter {
private String renderHiddenInputs(HttpServletRequest request) {
StringBuilder sb = new StringBuilder();
for (Map.Entry<String, String> input : this.resolveHiddenInputs.apply(request).entrySet()) {
sb.append("<input name=\"").append(input.getKey()).append("\" type=\"hidden\" value=\"")
.append(input.getValue()).append("\" />\n");
sb.append("<input name=\"");
sb.append(input.getKey());
sb.append("\" type=\"hidden\" value=\"");
sb.append(input.getValue());
sb.append("\" />\n");
}
return sb.toString();
}

View File

@ -80,29 +80,17 @@ public class BasicAuthenticationConverter implements AuthenticationConverter {
if (header == null) {
return null;
}
header = header.trim();
if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) {
return null;
}
if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) {
throw new BadCredentialsException("Empty basic authentication token");
}
byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8);
byte[] decoded;
try {
decoded = Base64.getDecoder().decode(base64Token);
}
catch (IllegalArgumentException ex) {
throw new BadCredentialsException("Failed to decode basic authentication token");
}
byte[] decoded = decode(base64Token);
String token = new String(decoded, getCredentialsCharset(request));
int delim = token.indexOf(":");
if (delim == -1) {
throw new BadCredentialsException("Invalid basic authentication token");
}
@ -112,6 +100,15 @@ public class BasicAuthenticationConverter implements AuthenticationConverter {
return result;
}
private byte[] decode(byte[] base64Token) {
try {
return Base64.getDecoder().decode(base64Token);
}
catch (IllegalArgumentException ex) {
throw new BadCredentialsException("Failed to decode basic authentication token");
}
}
protected Charset getCredentialsCharset(HttpServletRequest request) {
return getCredentialsCharset();
}

View File

@ -24,6 +24,7 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
@ -132,7 +133,6 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
@Override
public void afterPropertiesSet() {
Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
if (!isIgnoreFailure()) {
Assert.notNull(this.authenticationEntryPoint, "An AuthenticationEntryPoint is required");
}
@ -141,53 +141,34 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
final boolean debug = this.logger.isDebugEnabled();
try {
UsernamePasswordAuthenticationToken authRequest = this.authenticationConverter.convert(request);
if (authRequest == null) {
chain.doFilter(request, response);
return;
}
String username = authRequest.getName();
if (debug) {
this.logger.debug("Basic Authentication Authorization header found for user '" + username + "'");
}
this.logger.debug(
LogMessage.format("Basic Authentication Authorization header found for user '%s'", username));
if (authenticationIsRequired(username)) {
Authentication authResult = this.authenticationManager.authenticate(authRequest);
if (debug) {
this.logger.debug("Authentication success: " + authResult);
}
this.logger.debug(LogMessage.format("Authentication success: %s", authResult));
SecurityContextHolder.getContext().setAuthentication(authResult);
this.rememberMeServices.loginSuccess(request, response, authResult);
onSuccessfulAuthentication(request, response, authResult);
}
}
catch (AuthenticationException failed) {
catch (AuthenticationException ex) {
SecurityContextHolder.clearContext();
if (debug) {
this.logger.debug("Authentication request for failed!", failed);
}
this.logger.debug("Authentication request for failed!", ex);
this.rememberMeServices.loginFail(request, response);
onUnsuccessfulAuthentication(request, response, failed);
onUnsuccessfulAuthentication(request, response, ex);
if (this.ignoreFailure) {
chain.doFilter(request, response);
}
else {
this.authenticationEntryPoint.commence(request, response, failed);
this.authenticationEntryPoint.commence(request, response, ex);
}
return;
}
@ -196,40 +177,26 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
private boolean authenticationIsRequired(String username) {
// Only reauthenticate if username doesn't match SecurityContextHolder and user
// isn't authenticated
// (see SEC-53)
// isn't authenticated (see SEC-53)
Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication();
if (existingAuth == null || !existingAuth.isAuthenticated()) {
return true;
}
// Limit username comparison to providers which use usernames (ie
// UsernamePasswordAuthenticationToken)
// (see SEC-348)
// UsernamePasswordAuthenticationToken) (see SEC-348)
if (existingAuth instanceof UsernamePasswordAuthenticationToken && !existingAuth.getName().equals(username)) {
return true;
}
// Handle unusual condition where an AnonymousAuthenticationToken is already
// present
// This shouldn't happen very often, as BasicProcessingFitler is meant to be
// earlier in the filter
// chain than AnonymousAuthenticationFilter. Nevertheless, presence of both an
// AnonymousAuthenticationToken
// together with a BASIC authentication request header should indicate
// reauthentication using the
// present. This shouldn't happen very often, as BasicProcessingFitler is meant to
// be earlier in the filter chain than AnonymousAuthenticationFilter.
// Nevertheless, presence of both an AnonymousAuthenticationToken together with a
// BASIC authentication request header should indicate reauthentication using the
// BASIC protocol is desirable. This behaviour is also consistent with that
// provided by form and digest,
// both of which force re-authentication if the respective header is detected (and
// in doing so replace
// any existing AnonymousAuthenticationToken). See SEC-610.
if (existingAuth instanceof AnonymousAuthenticationToken) {
return true;
}
return false;
// provided by form and digest, both of which force re-authentication if the
// respective header is detected (and in doing so replace/ any existing
// AnonymousAuthenticationToken). See SEC-610.
return (existingAuth instanceof AnonymousAuthenticationToken);
}
protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,

View File

@ -44,18 +44,14 @@ final class DigestAuthUtils {
if (str == null) {
return null;
}
int len = str.length();
if (len == 0) {
return EMPTY_STRING_ARRAY;
}
List<String> list = new ArrayList<>();
int i = 0;
int start = 0;
boolean match = false;
while (i < len) {
if (str.charAt(i) == '"') {
i++;
@ -83,7 +79,6 @@ final class DigestAuthUtils {
if (match) {
list.add(str.substring(start, i));
}
return list.toArray(new String[0]);
}
@ -108,32 +103,19 @@ final class DigestAuthUtils {
static String generateDigest(boolean passwordAlreadyEncoded, String username, String realm, String password,
String httpMethod, String uri, String qop, String nonce, String nc, String cnonce)
throws IllegalArgumentException {
String a1Md5;
String a2 = httpMethod + ":" + uri;
String a1Md5 = (!passwordAlreadyEncoded) ? DigestAuthUtils.encodePasswordInA1Format(username, realm, password)
: password;
String a2Md5 = md5Hex(a2);
if (passwordAlreadyEncoded) {
a1Md5 = password;
}
else {
a1Md5 = DigestAuthUtils.encodePasswordInA1Format(username, realm, password);
}
String digest;
if (qop == null) {
// as per RFC 2069 compliant clients (also reaffirmed by RFC 2617)
digest = a1Md5 + ":" + nonce + ":" + a2Md5;
return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5);
}
else if ("auth".equals(qop)) {
if ("auth".equals(qop)) {
// As per RFC 2617 compliant clients
digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5;
return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5);
}
else {
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
return md5Hex(digest);
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
/**
@ -157,28 +139,15 @@ final class DigestAuthUtils {
if ((array == null) || (array.length == 0)) {
return null;
}
Map<String, String> map = new HashMap<>();
for (String s : array) {
String postRemove;
if (removeCharacters == null) {
postRemove = s;
}
else {
postRemove = StringUtils.replace(s, removeCharacters, "");
}
String postRemove = (removeCharacters != null) ? StringUtils.replace(s, removeCharacters, "") : s;
String[] splitThisArrayElement = split(postRemove, delimiter);
if (splitThisArrayElement == null) {
continue;
}
map.put(splitThisArrayElement[0].trim(), splitThisArrayElement[1].trim());
}
return map;
}
@ -196,33 +165,24 @@ final class DigestAuthUtils {
static String[] split(String toSplit, String delimiter) {
Assert.hasLength(toSplit, "Cannot split a null or empty string");
Assert.hasLength(delimiter, "Cannot use a null or empty delimiter to split a string");
if (delimiter.length() != 1) {
throw new IllegalArgumentException("Delimiter can only be one character in length");
}
Assert.isTrue(delimiter.length() == 1, "Delimiter can only be one character in length");
int offset = toSplit.indexOf(delimiter);
if (offset < 0) {
return null;
}
String beforeDelimiter = toSplit.substring(0, offset);
String afterDelimiter = toSplit.substring(offset + 1);
return new String[] { beforeDelimiter, afterDelimiter };
}
static String md5Hex(String data) {
MessageDigest digest;
try {
digest = MessageDigest.getInstance("MD5");
MessageDigest digest = MessageDigest.getInstance("MD5");
return new String(Hex.encode(digest.digest(data.getBytes())));
}
catch (NoSuchAlgorithmException ex) {
throw new IllegalStateException("No MD5 algorithm available!");
}
return new String(Hex.encode(digest.digest(data.getBytes())));
}
}

View File

@ -27,9 +27,11 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.Ordered;
import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.util.Assert;
/**
* Used by the <code>SecurityEnforcementFilter</code> to commence authentication via the
@ -68,44 +70,30 @@ public class DigestAuthenticationEntryPoint implements AuthenticationEntryPoint,
@Override
public void afterPropertiesSet() {
if ((this.realmName == null) || "".equals(this.realmName)) {
throw new IllegalArgumentException("realmName must be specified");
}
if ((this.key == null) || "".equals(this.key)) {
throw new IllegalArgumentException("key must be specified");
}
Assert.hasLength(this.realmName, "realmName must be specified");
Assert.hasLength(this.key, "key must be specified");
}
@Override
public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException {
HttpServletResponse httpResponse = response;
// compute a nonce (do not use remote IP address due to proxy farms)
// format of nonce is:
// base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
// compute a nonce (do not use remote IP address due to proxy farms) format of
// nonce is: base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
long expiryTime = System.currentTimeMillis() + (this.nonceValiditySeconds * 1000);
String signatureValue = DigestAuthUtils.md5Hex(expiryTime + ":" + this.key);
String nonceValue = expiryTime + ":" + signatureValue;
String nonceValueBase64 = new String(Base64.getEncoder().encode(nonceValue.getBytes()));
// qop is quality of protection, as defined by RFC 2617.
// we do not use opaque due to IE violation of RFC 2617 in not
// representing opaque on subsequent requests in same session.
// qop is quality of protection, as defined by RFC 2617. We do not use opaque due
// to IE violation of RFC 2617 in not representing opaque on subsequent requests
// in same session.
String authenticateHeader = "Digest realm=\"" + this.realmName + "\", " + "qop=\"auth\", nonce=\""
+ nonceValueBase64 + "\"";
if (authException instanceof NonceExpiredException) {
authenticateHeader = authenticateHeader + ", stale=\"true\"";
}
if (logger.isDebugEnabled()) {
logger.debug("WWW-Authenticate header sent to user agent: " + authenticateHeader);
}
httpResponse.addHeader("WWW-Authenticate", authenticateHeader);
httpResponse.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
logger.debug(LogMessage.format("WWW-Authenticate header sent to user agent: %s", authenticateHeader));
response.addHeader("WWW-Authenticate", authenticateHeader);
response.sendError(HttpStatus.UNAUTHORIZED.value(), HttpStatus.UNAUTHORIZED.getReasonPhrase());
}
public String getKey() {

View File

@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.MessageSourceAccessor;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.BadCredentialsException;
@ -112,136 +113,105 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
String header = request.getHeader("Authorization");
if (header == null || !header.startsWith("Digest ")) {
chain.doFilter(request, response);
return;
}
if (logger.isDebugEnabled()) {
logger.debug("Digest Authorization header received from user agent: " + header);
}
logger.debug(LogMessage.format("Digest Authorization header received from user agent: %s", header));
DigestData digestAuth = new DigestData(header);
try {
digestAuth.validateAndDecode(this.authenticationEntryPoint.getKey(),
this.authenticationEntryPoint.getRealmName());
}
catch (BadCredentialsException ex) {
fail(request, response, ex);
return;
}
// Lookup password for presented username
// NB: DAO-provided password MUST be clear text - not encoded/salted
// (unless this instance's passwordAlreadyEncoded property is 'false')
// Lookup password for presented username. N.B. DAO-provided password MUST be
// clear text - not encoded/salted (unless this instance's passwordAlreadyEncoded
// property is 'false')
boolean cacheWasUsed = true;
UserDetails user = this.userCache.getUserFromCache(digestAuth.getUsername());
String serverDigestMd5;
try {
if (user == null) {
cacheWasUsed = false;
user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername());
if (user == null) {
throw new AuthenticationServiceException(
"AuthenticationDao returned null, which is an interface contract violation");
}
this.userCache.putUserInCache(user);
}
serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod());
// If digest is incorrect, try refreshing from backend and recomputing
if (!serverDigestMd5.equals(digestAuth.getResponse()) && cacheWasUsed) {
if (logger.isDebugEnabled()) {
logger.debug(
"Digest comparison failure; trying to refresh user from DAO in case password had changed");
}
logger.debug("Digest comparison failure; trying to refresh user from DAO in case password had changed");
user = this.userDetailsService.loadUserByUsername(digestAuth.getUsername());
this.userCache.putUserInCache(user);
serverDigestMd5 = digestAuth.calculateServerDigest(user.getPassword(), request.getMethod());
}
}
catch (UsernameNotFoundException notFound) {
fail(request, response,
new BadCredentialsException(this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound",
new Object[] { digestAuth.getUsername() }, "Username {0} not found")));
catch (UsernameNotFoundException ex) {
String message = this.messages.getMessage("DigestAuthenticationFilter.usernameNotFound",
new Object[] { digestAuth.getUsername() }, "Username {0} not found");
fail(request, response, new BadCredentialsException(message));
return;
}
// If digest is still incorrect, definitely reject authentication attempt
if (!serverDigestMd5.equals(digestAuth.getResponse())) {
if (logger.isDebugEnabled()) {
logger.debug("Expected response: '" + serverDigestMd5 + "' but received: '" + digestAuth.getResponse()
+ "'; is AuthenticationDao returning clear text passwords?");
}
fail(request, response, new BadCredentialsException(
this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse", "Incorrect response")));
logger.debug(LogMessage.format(
"Expected response: '%s' but received: '%s'; is AuthenticationDao returning clear text passwords?",
serverDigestMd5, digestAuth.getResponse()));
String message = this.messages.getMessage("DigestAuthenticationFilter.incorrectResponse",
"Incorrect response");
fail(request, response, new BadCredentialsException(message));
return;
}
// To get this far, the digest must have been valid
// Check the nonce has not expired
// We do this last so we can direct the user agent its nonce is stale
// but the request was otherwise appearing to be valid
if (digestAuth.isNonceExpired()) {
fail(request, response, new NonceExpiredException(this.messages
.getMessage("DigestAuthenticationFilter.nonceExpired", "Nonce has expired/timed out")));
String message = this.messages.getMessage("DigestAuthenticationFilter.nonceExpired",
"Nonce has expired/timed out");
fail(request, response, new NonceExpiredException(message));
return;
}
if (logger.isDebugEnabled()) {
logger.debug("Authentication success for user: '" + digestAuth.getUsername() + "' with response: '"
+ digestAuth.getResponse() + "'");
}
logger.debug(LogMessage.format("Authentication success for user: '%s' with response: '%s'",
digestAuth.getUsername(), digestAuth.getResponse()));
Authentication authentication = createSuccessfulAuthentication(request, user);
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
chain.doFilter(request, response);
}
private Authentication createSuccessfulAuthentication(HttpServletRequest request, UserDetails user) {
UsernamePasswordAuthenticationToken authRequest;
if (this.createAuthenticatedToken) {
authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities());
}
else {
authRequest = new UsernamePasswordAuthenticationToken(user, user.getPassword());
}
UsernamePasswordAuthenticationToken authRequest = getAuthRequest(user);
authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
return authRequest;
}
private UsernamePasswordAuthenticationToken getAuthRequest(UserDetails user) {
if (this.createAuthenticatedToken) {
return new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities());
}
return new UsernamePasswordAuthenticationToken(user, user.getPassword());
}
private void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed)
throws IOException, ServletException {
SecurityContextHolder.getContext().setAuthentication(null);
if (logger.isDebugEnabled()) {
logger.debug(failed);
}
logger.debug(failed);
this.authenticationEntryPoint.commence(request, response, failed);
}
@ -326,7 +296,6 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
this.section212response = header.substring(7);
String[] headerEntries = DigestAuthUtils.splitIgnoringQuotes(this.section212response, ',');
Map<String, String> headerMap = DigestAuthUtils.splitEachArrayElementAndCreateMap(headerEntries, "=", "\"");
this.username = headerMap.get("username");
this.realm = headerMap.get("realm");
this.nonce = headerMap.get("nonce");
@ -335,11 +304,9 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
this.qop = headerMap.get("qop"); // RFC 2617 extension
this.nc = headerMap.get("nc"); // RFC 2617 extension
this.cnonce = headerMap.get("cnonce"); // RFC 2617 extension
if (logger.isDebugEnabled()) {
logger.debug("Extracted username: '" + this.username + "'; realm: '" + this.realm + "'; nonce: '"
+ this.nonce + "'; uri: '" + this.uri + "'; response: '" + this.response + "'");
}
logger.debug(
LogMessage.format("Extracted username: '%s'; realm: '%s'; nonce: '%s'; uri: '%s'; response: '%s'",
this.username, this.realm, this.nonce, this.uri, this.response));
}
void validateAndDecode(String entryPointKey, String expectedRealm) throws BadCredentialsException {
@ -353,23 +320,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
// Check all required parameters for an "auth" qop were supplied (ie RFC 2617)
if ("auth".equals(this.qop)) {
if ((this.nc == null) || (this.cnonce == null)) {
if (logger.isDebugEnabled()) {
logger.debug("extracted nc: '" + this.nc + "'; cnonce: '" + this.cnonce + "'");
}
logger.debug(LogMessage.format("extracted nc: '%s'; cnonce: '%s'", this.nc, this.cnonce));
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.missingAuth", new Object[] { this.section212response },
"Missing mandatory digest value; received header {0}"));
}
}
// Check realm name equals what we expected
if (!expectedRealm.equals(this.realm)) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.incorrectRealm", new Object[] { this.realm, expectedRealm },
"Response realm name '{0}' does not match system realm name of '{1}'"));
}
// Check nonce was Base64 encoded (as sent by DigestAuthenticationEntryPoint)
try {
Base64.getDecoder().decode(this.nonce.getBytes());
@ -379,21 +341,16 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
DigestAuthenticationFilter.this.messages.getMessage("DigestAuthenticationFilter.nonceEncoding",
new Object[] { this.nonce }, "Nonce is not encoded in Base64; received nonce {0}"));
}
// Decode nonce from Base64
// format of nonce is:
// base64(expirationTime + ":" + md5Hex(expirationTime + ":" + key))
// Decode nonce from Base64 format of nonce is: base64(expirationTime + ":" +
// md5Hex(expirationTime + ":" + key))
String nonceAsPlainText = new String(Base64.getDecoder().decode(this.nonce.getBytes()));
String[] nonceTokens = StringUtils.delimitedListToStringArray(nonceAsPlainText, ":");
if (nonceTokens.length != 2) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.nonceNotTwoTokens", new Object[] { nonceAsPlainText },
"Nonce should have yielded two tokens but was {0}"));
}
// Extract expiry time from nonce
try {
this.nonceExpiryTime = new Long(nonceTokens[0]);
}
@ -402,10 +359,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
"DigestAuthenticationFilter.nonceNotNumeric", new Object[] { nonceAsPlainText },
"Nonce token should have yielded a numeric first token, but was {0}"));
}
// Check signature of nonce matches this expiry time
String expectedNonceSignature = DigestAuthUtils.md5Hex(this.nonceExpiryTime + ":" + entryPointKey);
if (!expectedNonceSignature.equals(nonceTokens[1])) {
throw new BadCredentialsException(DigestAuthenticationFilter.this.messages.getMessage(
"DigestAuthenticationFilter.nonceCompromised", new Object[] { nonceAsPlainText },
@ -414,9 +369,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
}
String calculateServerDigest(String password, String httpMethod) {
// Compute the expected response-digest (will be in hex form)
// Don't catch IllegalArgumentException (already checked validity)
// Compute the expected response-digest (will be in hex form). Don't catch
// IllegalArgumentException (already checked validity)
return DigestAuthUtils.generateDigest(DigestAuthenticationFilter.this.passwordAlreadyEncoded, this.username,
this.realm, password, httpMethod, this.uri, this.qop, this.nonce, this.nc, this.cnonce);
}

View File

@ -105,9 +105,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
}
else {
return null;
}
return null;
}
return principal;
}

View File

@ -173,11 +173,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp
*/
private void registerFilters(ServletContext servletContext, boolean insertBeforeOtherFilters, Filter... filters) {
Assert.notEmpty(filters, "filters cannot be null or empty");
for (Filter filter : filters) {
if (filter == null) {
throw new IllegalArgumentException("filters cannot contain null values. Got " + Arrays.asList(filters));
}
Assert.notNull(filter, () -> "filters cannot contain null values. Got " + Arrays.asList(filters));
String filterName = Conventions.getVariableName(filter);
registerFilter(servletContext, insertBeforeOtherFilters, filterName, filter);
}
@ -195,10 +192,8 @@ public abstract class AbstractSecurityWebApplicationInitializer implements WebAp
private void registerFilter(ServletContext servletContext, boolean insertBeforeOtherFilters, String filterName,
Filter filter) {
Dynamic registration = servletContext.addFilter(filterName, filter);
if (registration == null) {
throw new IllegalStateException("Duplicate Filter registration for '" + filterName
+ "'. Check to ensure the Filter is only configured once.");
}
Assert.state(registration != null, () -> "Duplicate Filter registration for '" + filterName
+ "'. Check to ensure the Filter is only configured once.");
registration.setAsyncSupported(isAsyncSecuritySupported());
EnumSet<DispatcherType> dispatcherTypes = getSecurityDispatcherTypes();
registration.addMappingForUrlPatterns(dispatcherTypes, !insertBeforeOtherFilters, "/*");

View File

@ -28,6 +28,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication;
@ -115,24 +116,18 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
HttpServletRequest request = requestResponseHolder.getRequest();
HttpServletResponse response = requestResponseHolder.getResponse();
HttpSession httpSession = request.getSession(false);
SecurityContext context = readSecurityContextFromSession(httpSession);
if (context == null) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("No SecurityContext was available from the HttpSession: " + httpSession + ". "
+ "A new one will be created.");
}
this.logger.debug(LogMessage.format(
"No SecurityContext was available from the HttpSession: %s. A new one will be created.",
httpSession));
context = generateNewContext();
}
SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request,
httpSession != null, context);
requestResponseHolder.setResponse(wrappedResponse);
requestResponseHolder.setRequest(new SaveToSessionRequestWrapper(request, wrappedResponse));
return context;
}
@ -140,13 +135,10 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response,
SaveContextOnUpdateOrErrorResponseWrapper.class);
if (responseWrapper == null) {
throw new IllegalStateException("Cannot invoke saveContext on response " + response
+ ". You must use the HttpRequestResponseHolder.response after invoking loadContext");
}
// saveContext() might already be called by the response wrapper
// if something in the chain called sendError() or sendRedirect(). This ensures we
// only call it
Assert.state(responseWrapper != null, () -> "Cannot invoke saveContext on response " + response
+ ". You must use the HttpRequestResponseHolder.response after invoking loadContext");
// saveContext() might already be called by the response wrapper if something in
// the chain called sendError() or sendRedirect(). This ensures we only call it
// once per request.
if (!responseWrapper.isContextSaved()) {
responseWrapper.saveContext(context);
@ -156,11 +148,9 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
@Override
public boolean containsContext(HttpServletRequest request) {
HttpSession session = request.getSession(false);
if (session == null) {
return false;
}
return session.getAttribute(this.springSecurityContextKey) != null;
}
@ -168,47 +158,30 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
* @param httpSession the session obtained from the request.
*/
private SecurityContext readSecurityContextFromSession(HttpSession httpSession) {
final boolean debug = this.logger.isDebugEnabled();
if (httpSession == null) {
if (debug) {
this.logger.debug("No HttpSession currently exists");
}
this.logger.debug("No HttpSession currently exists");
return null;
}
// Session exists, so try to obtain a context from it.
Object contextFromSession = httpSession.getAttribute(this.springSecurityContextKey);
if (contextFromSession == null) {
if (debug) {
this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT");
}
this.logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT");
return null;
}
// We now have the security context object from the session.
if (!(contextFromSession instanceof SecurityContext)) {
if (this.logger.isWarnEnabled()) {
this.logger.warn(this.springSecurityContextKey + " did not contain a SecurityContext but contained: '"
+ contextFromSession + "'; are you improperly modifying the HttpSession directly "
+ "(you should always use SecurityContextHolder) or using the HttpSession attribute "
+ "reserved for this class?");
}
this.logger.warn(LogMessage.format(
"%s did not contain a SecurityContext but contained: '%s'; are you improperly "
+ "modifying the HttpSession directly (you should always use SecurityContextHolder) "
+ "or using the HttpSession attribute reserved for this class?",
this.springSecurityContextKey, contextFromSession));
return null;
}
if (debug) {
this.logger.debug("Obtained a valid SecurityContext from " + this.springSecurityContextKey + ": '"
+ contextFromSession + "'");
}
this.logger.debug(LogMessage.format("Obtained a valid SecurityContext from %s: '%s'",
this.springSecurityContextKey, contextFromSession));
// Everything OK. The only non-null return from this method.
return (SecurityContext) contextFromSession;
}
@ -306,6 +279,8 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
*/
final class SaveToSessionResponseWrapper extends SaveContextOnUpdateOrErrorResponseWrapper {
private final Log logger = HttpSessionSecurityContextRepository.this.logger;
private final HttpServletRequest request;
private final boolean httpSessionExistedAtStartOfRequest;
@ -349,41 +324,29 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
protected void saveContext(SecurityContext context) {
final Authentication authentication = context.getAuthentication();
HttpSession httpSession = this.request.getSession(false);
String springSecurityContextKey = HttpSessionSecurityContextRepository.this.springSecurityContextKey;
// See SEC-776
if (authentication == null
|| HttpSessionSecurityContextRepository.this.trustResolver.isAnonymous(authentication)) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger.debug(
"SecurityContext is empty or contents are anonymous - context will not be stored in HttpSession.");
}
this.logger.debug("SecurityContext is empty or contents are anonymous - "
+ "context will not be stored in HttpSession.");
if (httpSession != null && this.authBeforeExecution != null) {
// SEC-1587 A non-anonymous context may still be in the session
// SEC-1735 remove if the contextBeforeExecution was not anonymous
httpSession.removeAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey);
httpSession.removeAttribute(springSecurityContextKey);
}
return;
}
if (httpSession == null) {
httpSession = createNewSessionIfAllowed(context);
}
httpSession = (httpSession != null) ? httpSession : createNewSessionIfAllowed(context);
// If HttpSession exists, store current SecurityContext but only if it has
// actually changed in this thread (see SEC-37, SEC-1307, SEC-1528)
if (httpSession != null) {
// We may have a new session, so check also whether the context attribute
// is set SEC-1561
if (contextChanged(context) || httpSession
.getAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey) == null) {
httpSession.setAttribute(HttpSessionSecurityContextRepository.this.springSecurityContextKey,
context);
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger
.debug("SecurityContext '" + context + "' stored to HttpSession: '" + httpSession);
}
if (contextChanged(context) || httpSession.getAttribute(springSecurityContextKey) == null) {
httpSession.setAttribute(springSecurityContextKey, context);
this.logger.debug(LogMessage.format("SecurityContext '%s' stored to HttpSession: '%s'", context,
httpSession));
}
}
}
@ -396,56 +359,37 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
if (isTransientAuthentication(context.getAuthentication())) {
return null;
}
if (this.httpSessionExistedAtStartOfRequest) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger
.debug("HttpSession is now null, but was not null at start of request; "
+ "session was invalidated, so do not create a new session");
}
this.logger.debug("HttpSession is now null, but was not null at start of request; "
+ "session was invalidated, so do not create a new session");
return null;
}
if (!HttpSessionSecurityContextRepository.this.allowSessionCreation) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger.debug("The HttpSession is currently null, and the "
+ HttpSessionSecurityContextRepository.class.getSimpleName()
+ " is prohibited from creating an HttpSession "
+ "(because the allowSessionCreation property is false) - SecurityContext thus not "
+ "stored for next request");
}
this.logger.debug("The HttpSession is currently null, and the "
+ HttpSessionSecurityContextRepository.class.getSimpleName()
+ " is prohibited from creating an HttpSession "
+ "(because the allowSessionCreation property is false) - SecurityContext thus not "
+ "stored for next request");
return null;
}
// Generate a HttpSession only if we need to
if (HttpSessionSecurityContextRepository.this.contextObject.equals(context)) {
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger.debug(
"HttpSession is null, but SecurityContext has not changed from default empty context: ' "
+ context + "'; not creating HttpSession or storing SecurityContext");
}
this.logger.debug(LogMessage.format(
"HttpSession is null, but SecurityContext has not changed from "
+ "default empty context: '%s'; not creating HttpSession or storing SecurityContext",
context));
return null;
}
if (HttpSessionSecurityContextRepository.this.logger.isDebugEnabled()) {
HttpSessionSecurityContextRepository.this.logger
.debug("HttpSession being created as SecurityContext is non-default");
}
this.logger.debug("HttpSession being created as SecurityContext is non-default");
try {
return this.request.getSession(true);
}
catch (IllegalStateException ex) {
// Response must already be committed, therefore can't create a new
// session
HttpSessionSecurityContextRepository.this.logger
.warn("Failed to create a session, as response has been committed. Unable to store"
+ " SecurityContext.");
this.logger.warn("Failed to create a session, as response has been committed. "
+ "Unable to store SecurityContext.");
}
return null;
}

View File

@ -44,7 +44,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends OnCommit
private boolean contextSaved = false;
/* See SEC-1052 */
// See SEC-1052
private final boolean disableUrlRewriting;
/**

View File

@ -26,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.GenericFilterBean;
@ -74,49 +75,36 @@ public class SecurityContextPersistenceFilter extends GenericFilterBean {
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
// ensure that filter is only applied once per request
if (request.getAttribute(FILTER_APPLIED) != null) {
// ensure that filter is only applied once per request
chain.doFilter(request, response);
return;
}
final boolean debug = this.logger.isDebugEnabled();
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
if (this.forceEagerSessionCreation) {
HttpSession session = request.getSession();
if (debug && session.isNew()) {
this.logger.debug("Eagerly created session: " + session.getId());
}
this.logger.debug(LogMessage.format("Eagerly created session: %s", session.getId()));
}
HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
SecurityContext contextBeforeChainExecution = this.repo.loadContext(holder);
try {
SecurityContextHolder.setContext(contextBeforeChainExecution);
chain.doFilter(holder.getRequest(), holder.getResponse());
}
finally {
SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext();
// Crucial removal of SecurityContextHolder contents - do this before anything
// else.
// Crucial removal of SecurityContextHolder contents before anything else.
SecurityContextHolder.clearContext();
this.repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse());
request.removeAttribute(FILTER_APPLIED);
if (debug) {
this.logger.debug("SecurityContextHolder now cleared, as request processing completed");
}
this.logger.debug("SecurityContextHolder now cleared, as request processing completed");
}
}

View File

@ -46,14 +46,12 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager
.getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY);
if (securityProcessingInterceptor == null) {
asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY,
new SecurityContextCallableProcessingInterceptor());
}
filterChain.doFilter(request, response);
}

View File

@ -20,6 +20,7 @@ import java.util.Enumeration;
import javax.servlet.ServletContext;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
@ -47,11 +48,10 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC
* @see ServletContext#getAttributeNames()
*/
public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) {
WebApplicationContext wac = _findWebApplicationContext(servletContext);
if (wac == null) {
throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?");
}
return wac;
WebApplicationContext webApplicationContext = compatiblyFindWebApplicationContext(servletContext);
Assert.state(webApplicationContext != null,
"No WebApplicationContext found: no ContextLoaderListener registered?");
return webApplicationContext;
}
/**
@ -59,23 +59,21 @@ public abstract class SecurityWebApplicationContextUtils extends WebApplicationC
* spring framework 4.1.x.
* @see #findWebApplicationContext(ServletContext)
*/
private static WebApplicationContext _findWebApplicationContext(ServletContext sc) {
WebApplicationContext wac = getWebApplicationContext(sc);
if (wac == null) {
private static WebApplicationContext compatiblyFindWebApplicationContext(ServletContext sc) {
WebApplicationContext webApplicationContext = getWebApplicationContext(sc);
if (webApplicationContext == null) {
Enumeration<String> attrNames = sc.getAttributeNames();
while (attrNames.hasMoreElements()) {
String attrName = attrNames.nextElement();
Object attrValue = sc.getAttribute(attrName);
if (attrValue instanceof WebApplicationContext) {
if (wac != null) {
throw new IllegalStateException("No unique WebApplicationContext found: more than one "
+ "DispatcherServlet registered with publishContext=true?");
}
wac = (WebApplicationContext) attrValue;
Assert.state(webApplicationContext == null, "No unique WebApplicationContext found: more than one "
+ "DispatcherServlet registered with publishContext=true?");
webApplicationContext = (WebApplicationContext) attrValue;
}
}
}
return wac;
return webApplicationContext;
}
}

View File

@ -69,30 +69,13 @@ public final class CookieCsrfTokenRepository implements CsrfTokenRepository {
public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
String tokenValue = (token != null) ? token.getToken() : "";
Cookie cookie = new Cookie(this.cookieName, tokenValue);
if (this.secure == null) {
cookie.setSecure(request.isSecure());
}
else {
cookie.setSecure(this.secure);
}
if (this.cookiePath != null && !this.cookiePath.isEmpty()) {
cookie.setPath(this.cookiePath);
}
else {
cookie.setPath(this.getRequestContext(request));
}
if (token == null) {
cookie.setMaxAge(0);
}
else {
cookie.setMaxAge(-1);
}
cookie.setSecure((this.secure != null) ? this.secure : request.isSecure());
cookie.setPath(StringUtils.hasLength(this.cookiePath) ? this.cookiePath : this.getRequestContext(request));
cookie.setMaxAge((token != null) ? -1 : 0);
cookie.setHttpOnly(this.cookieHttpOnly);
if (this.cookieDomain != null && !this.cookieDomain.isEmpty()) {
if (StringUtils.hasLength(this.cookieDomain)) {
cookie.setDomain(this.cookieDomain);
}
response.addCookie(cookie);
}

View File

@ -51,10 +51,8 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
boolean containsToken = this.csrfTokenRepository.loadToken(request) != null;
if (containsToken) {
this.csrfTokenRepository.saveToken(null, request, response);
CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
this.csrfTokenRepository.saveToken(newToken, request, response);
request.setAttribute(CsrfToken.class.getName(), newToken);
request.setAttribute(newToken.getParameterName(), newToken);
}

View File

@ -29,6 +29,8 @@ import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.util.UrlUtils;
@ -97,39 +99,30 @@ public final class CsrfFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
request.setAttribute(HttpServletResponse.class.getName(), response);
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
final boolean missingToken = csrfToken == null;
boolean missingToken = (csrfToken == null);
if (missingToken) {
csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
if (!this.requireCsrfProtectionMatcher.matches(request)) {
filterChain.doFilter(request, response);
return;
}
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
if (!csrfToken.getToken().equals(actualToken)) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request));
}
if (missingToken) {
this.accessDeniedHandler.handle(request, response, new MissingCsrfTokenException(actualToken));
}
else {
this.accessDeniedHandler.handle(request, response,
new InvalidCsrfTokenException(csrfToken, actualToken));
}
this.logger.debug(
LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
: new MissingCsrfTokenException(actualToken);
this.accessDeniedHandler.handle(request, response, exception);
return;
}
filterChain.doFilter(request, response);
}

View File

@ -24,7 +24,6 @@ import java.io.Serializable;
* @author Rob Winch
* @since 3.2
* @see DefaultCsrfToken
*
*/
public interface CsrfToken extends Serializable {

View File

@ -87,11 +87,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
private HttpServletResponse getResponse(HttpServletRequest request) {
HttpServletResponse response = (HttpServletResponse) request.getAttribute(HTTP_RESPONSE_ATTR);
if (response == null) {
throw new IllegalArgumentException(
"The HttpServletRequest attribute must contain an HttpServletResponse for the attribute "
+ HTTP_RESPONSE_ATTR);
}
Assert.notNull(response, () -> "The HttpServletRequest attribute must contain an HttpServletResponse "
+ "for the attribute " + HTTP_RESPONSE_ATTR);
return response;
}
@ -166,7 +163,6 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
if (this.tokenRepository == null) {
return;
}
synchronized (this) {
if (this.tokenRepository != null) {
this.tokenRepository.saveToken(this.delegate, this.request, this.response);

View File

@ -50,35 +50,35 @@ public final class DebugFilter implements Filter {
static final String ALREADY_FILTERED_ATTR_NAME = DebugFilter.class.getName().concat(".FILTERED");
private final FilterChainProxy fcp;
private final FilterChainProxy filterChainProxy;
private final Logger logger = new Logger();
public DebugFilter(FilterChainProxy fcp) {
this.fcp = fcp;
public DebugFilter(FilterChainProxy filterChainProxy) {
this.filterChainProxy = filterChainProxy;
}
@Override
public void doFilter(ServletRequest srvltRequest, ServletResponse srvltResponse, FilterChain filterChain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if (!(srvltRequest instanceof HttpServletRequest) || !(srvltResponse instanceof HttpServletResponse)) {
if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
throw new ServletException("DebugFilter just supports HTTP requests");
}
HttpServletRequest request = (HttpServletRequest) srvltRequest;
HttpServletResponse response = (HttpServletResponse) srvltResponse;
doFilter((HttpServletRequest) request, (HttpServletResponse) response, filterChain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
List<Filter> filters = getFilters(request);
this.logger.info("Request received for " + request.getMethod() + " '" + UrlUtils.buildRequestUrl(request)
+ "':\n\n" + request + "\n\n" + "servletPath:" + request.getServletPath() + "\n" + "pathInfo:"
+ request.getPathInfo() + "\n" + "headers: \n" + formatHeaders(request) + "\n\n"
+ formatFilters(filters));
if (request.getAttribute(ALREADY_FILTERED_ATTR_NAME) == null) {
invokeWithWrappedRequest(request, response, filterChain);
}
else {
this.fcp.doFilter(request, response, filterChain);
this.filterChainProxy.doFilter(request, response, filterChain);
}
}
@ -87,7 +87,7 @@ public final class DebugFilter implements Filter {
request.setAttribute(ALREADY_FILTERED_ATTR_NAME, Boolean.TRUE);
request = new DebugRequestWrapper(request);
try {
this.fcp.doFilter(request, response, filterChain);
this.filterChainProxy.doFilter(request, response, filterChain);
}
finally {
request.removeAttribute(ALREADY_FILTERED_ATTR_NAME);
@ -134,7 +134,7 @@ public final class DebugFilter implements Filter {
}
private List<Filter> getFilters(HttpServletRequest request) {
for (SecurityFilterChain chain : this.fcp.getFilterChains()) {
for (SecurityFilterChain chain : this.filterChainProxy.getFilterChains()) {
if (chain.matches(request)) {
return chain.getFilters();
}
@ -163,11 +163,9 @@ public final class DebugFilter implements Filter {
public HttpSession getSession() {
boolean sessionExists = super.getSession(false) != null;
HttpSession session = super.getSession();
if (!sessionExists) {
DebugRequestWrapper.logger.info("New HTTP session created: " + session.getId(), true);
}
return session;
}

View File

@ -50,19 +50,17 @@ public class DefaultHttpFirewall implements HttpFirewall {
@Override
public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
FirewalledRequest fwr = new RequestWrapper(request);
if (!isNormalized(fwr.getServletPath()) || !isNormalized(fwr.getPathInfo())) {
throw new RequestRejectedException("Un-normalized paths are not supported: " + fwr.getServletPath()
+ ((fwr.getPathInfo() != null) ? fwr.getPathInfo() : ""));
FirewalledRequest firewalledRequest = new RequestWrapper(request);
if (!isNormalized(firewalledRequest.getServletPath()) || !isNormalized(firewalledRequest.getPathInfo())) {
throw new RequestRejectedException(
"Un-normalized paths are not supported: " + firewalledRequest.getServletPath()
+ ((firewalledRequest.getPathInfo() != null) ? firewalledRequest.getPathInfo() : ""));
}
String requestURI = fwr.getRequestURI();
String requestURI = firewalledRequest.getRequestURI();
if (containsInvalidUrlEncodedSlash(requestURI)) {
throw new RequestRejectedException("The requestURI cannot contain encoded slash. Got " + requestURI);
}
return fwr;
return firewalledRequest;
}
@Override
@ -89,11 +87,9 @@ public class DefaultHttpFirewall implements HttpFirewall {
if (this.allowUrlEncodedSlash || uri == null) {
return false;
}
if (uri.contains("%2f") || uri.contains("%2F")) {
return true;
}
return false;
}
@ -107,22 +103,18 @@ public class DefaultHttpFirewall implements HttpFirewall {
if (path == null) {
return true;
}
for (int j = path.length(); j > 0;) {
int i = path.lastIndexOf('/', j - 1);
int gap = j - i;
if (gap == 2 && path.charAt(i + 1) == '.') {
for (int i = path.length(); i > 0;) {
int slashIndex = path.lastIndexOf('/', i - 1);
int gap = i - slashIndex;
if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
// ".", "/./" or "/."
return false;
}
else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') {
if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
return false;
}
j = i;
i = slashIndex;
}
return true;
}

View File

@ -22,6 +22,8 @@ import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.springframework.util.Assert;
/**
* @author Luke Taylor
* @author Eddú Meléndez
@ -71,9 +73,7 @@ class FirewalledResponse extends HttpServletResponseWrapper {
}
void validateCrlf(String name, String value) {
if (hasCrlf(name) || hasCrlf(value)) {
throw new IllegalArgumentException("Invalid characters (CR/LF) in header " + name);
}
Assert.isTrue(!hasCrlf(name) && !hasCrlf(value), () -> "Invalid characters (CR/LF) in header " + name);
}
private boolean hasCrlf(String value) {

View File

@ -24,6 +24,8 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
/**
* A simple implementation of {@link RequestRejectedHandler} that sends an error with
* configurable status code.
@ -55,10 +57,8 @@ public class HttpStatusRequestRejectedHandler implements RequestRejectedHandler
@Override
public void handle(HttpServletRequest request, HttpServletResponse response,
RequestRejectedException requestRejectedException) throws IOException {
if (logger.isDebugEnabled()) {
logger.debug("Rejecting request due to: " + requestRejectedException.getMessage(),
requestRejectedException);
}
logger.debug(LogMessage.format("Rejecting request due to: %s", requestRejectedException.getMessage()),
requestRejectedException);
response.sendError(this.httpError);
}

View File

@ -74,10 +74,8 @@ final class RequestWrapper extends FirewalledRequest {
if (path == null) {
return null;
}
int scIndex = path.indexOf(';');
if (scIndex < 0) {
int semicolonIndex = path.indexOf(';');
if (semicolonIndex < 0) {
int doubleSlashIndex = path.indexOf("//");
if (doubleSlashIndex < 0) {
// Most likely case, no parameters in any segment and no '//', so no
@ -85,29 +83,23 @@ final class RequestWrapper extends FirewalledRequest {
return path;
}
}
StringTokenizer st = new StringTokenizer(path, "/");
StringTokenizer tokenizer = new StringTokenizer(path, "/");
StringBuilder stripped = new StringBuilder(path.length());
if (path.charAt(0) == '/') {
stripped.append('/');
}
while (st.hasMoreTokens()) {
String segment = st.nextToken();
scIndex = segment.indexOf(';');
if (scIndex >= 0) {
segment = segment.substring(0, scIndex);
while (tokenizer.hasMoreTokens()) {
String segment = tokenizer.nextToken();
semicolonIndex = segment.indexOf(';');
if (semicolonIndex >= 0) {
segment = segment.substring(0, semicolonIndex);
}
stripped.append(segment).append('/');
}
// Remove the trailing slash if the original path didn't have one
if (path.charAt(path.length() - 1) != '/') {
stripped.deleteCharAt(stripped.length() - 1);
}
return stripped.toString();
}

View File

@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;
/**
* <p>
@ -83,7 +84,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* Used to specify to {@link #setAllowedHttpMethods(Collection)} that any HTTP method
* should be allowed.
*/
private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.unmodifiableSet(Collections.emptySet());
private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.emptySet();
private static final String ENCODED_PERCENT = "%25";
@ -165,15 +166,9 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see #setUnsafeAllowAnyHttpMethod(boolean)
*/
public void setAllowedHttpMethods(Collection<String> allowedHttpMethods) {
if (allowedHttpMethods == null) {
throw new IllegalArgumentException("allowedHttpMethods cannot be null");
}
if (allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD;
}
else {
this.allowedHttpMethods = new HashSet<>(allowedHttpMethods);
}
Assert.notNull(allowedHttpMethods, "allowedHttpMethods cannot be null");
this.allowedHttpMethods = (allowedHttpMethods != ALLOW_ANY_HTTP_METHOD) ? new HashSet<>(allowedHttpMethods)
: ALLOW_ANY_HTTP_METHOD;
}
/**
@ -361,9 +356,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see Character#isDefined(int)
*/
public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
if (allowedHeaderNames == null) {
throw new IllegalArgumentException("allowedHeaderNames cannot be null");
}
Assert.notNull(allowedHeaderNames, "allowedHeaderNames cannot be null");
this.allowedHeaderNames = allowedHeaderNames;
}
@ -378,28 +371,20 @@ public class StrictHttpFirewall implements HttpFirewall {
* @see Character#isDefined(int)
*/
public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
if (allowedHeaderValues == null) {
throw new IllegalArgumentException("allowedHeaderValues cannot be null");
}
Assert.notNull(allowedHeaderValues, "allowedHeaderValues cannot be null");
this.allowedHeaderValues = allowedHeaderValues;
}
/*
/**
* Determines which parameter names should be allowed. The default is to reject header
* names that contain ISO control characters and characters that are not defined. </p>
*
* names that contain ISO control characters and characters that are not defined.
* @param allowedParameterNames the predicate for testing parameter names
*
* @see Character#isISOControl(int)
*
* @see Character#isDefined(int)
*
* @since 5.4
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
*/
public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
if (allowedParameterNames == null) {
throw new IllegalArgumentException("allowedParameterNames cannot be null");
}
Assert.notNull(allowedParameterNames, "allowedParameterNames cannot be null");
this.allowedParameterNames = allowedParameterNames;
}
@ -412,9 +397,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @since 5.4
*/
public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
if (allowedParameterValues == null) {
throw new IllegalArgumentException("allowedParameterValues cannot be null");
}
Assert.notNull(allowedParameterValues, "allowedParameterValues cannot be null");
this.allowedParameterValues = allowedParameterValues;
}
@ -426,9 +409,7 @@ public class StrictHttpFirewall implements HttpFirewall {
* @since 5.2
*/
public void setAllowedHostnames(Predicate<String> allowedHostnames) {
if (allowedHostnames == null) {
throw new IllegalArgumentException("allowedHostnames cannot be null");
}
Assert.notNull(allowedHostnames, "allowedHostnames cannot be null");
this.allowedHostnames = allowedHostnames;
}
@ -447,173 +428,15 @@ public class StrictHttpFirewall implements HttpFirewall {
rejectForbiddenHttpMethod(request);
rejectedBlocklistedUrls(request);
rejectedUntrustedHosts(request);
if (!isNormalized(request)) {
throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
}
String requestUri = request.getRequestURI();
if (!containsOnlyPrintableAsciiCharacters(requestUri)) {
throw new RequestRejectedException(
"The requestURI was rejected because it can only contain printable ASCII characters.");
}
return new FirewalledRequest(request) {
@Override
public long getDateHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getDateHeader(name);
}
@Override
public int getIntHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getIntHeader(name);
}
@Override
public String getHeader(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
String value = super.getHeader(name);
if (value != null && !StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Enumeration<String> getHeaders(String name) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + name + "\" is not allowed.");
}
Enumeration<String> valuesEnumeration = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return valuesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String value = valuesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \""
+ value + "\" is not allowed.");
}
return value;
}
};
}
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> namesEnumeration = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \""
+ name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String getParameter(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String value = super.getParameter(name);
if (value != null && !StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
for (String value : values) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \""
+ value + "\" is not allowed.");
}
}
}
return parameterMap;
}
@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> namesEnumeration = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \""
+ name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String[] getParameterValues(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value : values) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \""
+ value + "\" is not allowed.");
}
}
}
return values;
}
@Override
public void reset() {
}
};
return new StrictFirewalledRequest(request);
}
private void rejectForbiddenHttpMethod(HttpServletRequest request) {
@ -705,12 +528,11 @@ public class StrictHttpFirewall implements HttpFirewall {
private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
int length = uri.length();
for (int i = 0; i < length; i++) {
char c = uri.charAt(i);
if (c < '\u0020' || c > '\u007e') {
char ch = uri.charAt(i);
if (ch < '\u0020' || ch > '\u007e') {
return false;
}
}
return true;
}
@ -728,22 +550,17 @@ public class StrictHttpFirewall implements HttpFirewall {
if (path == null) {
return true;
}
for (int j = path.length(); j > 0;) {
int i = path.lastIndexOf('/', j - 1);
int gap = j - i;
if (gap == 2 && path.charAt(i + 1) == '.') {
// ".", "/./" or "/."
for (int i = path.length(); i > 0;) {
int slashIndex = path.lastIndexOf('/', i - 1);
int gap = i - slashIndex;
if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
return false; // ".", "/./" or "/."
}
if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
return false;
}
else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') {
return false;
}
j = i;
i = slashIndex;
}
return true;
}
@ -782,4 +599,166 @@ public class StrictHttpFirewall implements HttpFirewall {
return getDecodedUrlBlocklist();
}
/**
* Strict {@link FirewalledRequest}.
*/
private class StrictFirewalledRequest extends FirewalledRequest {
StrictFirewalledRequest(HttpServletRequest request) {
super(request);
}
@Override
public long getDateHeader(String name) {
validateAllowedHeaderName(name);
return super.getDateHeader(name);
}
@Override
public int getIntHeader(String name) {
validateAllowedHeaderName(name);
return super.getIntHeader(name);
}
@Override
public String getHeader(String name) {
validateAllowedHeaderName(name);
String value = super.getHeader(name);
if (value != null) {
validateAllowedHeaderValue(value);
}
return value;
}
@Override
public Enumeration<String> getHeaders(String name) {
validateAllowedHeaderName(name);
Enumeration<String> headers = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return headers.hasMoreElements();
}
@Override
public String nextElement() {
String value = headers.nextElement();
validateAllowedHeaderValue(value);
return value;
}
};
}
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> names = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return names.hasMoreElements();
}
@Override
public String nextElement() {
String headerNames = names.nextElement();
validateAllowedHeaderName(headerNames);
return headerNames;
}
};
}
@Override
public String getParameter(String name) {
validateAllowedParameterName(name);
String value = super.getParameter(name);
if (value != null) {
validateAllowedParameterValue(value);
}
return value;
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
validateAllowedParameterName(name);
for (String value : values) {
validateAllowedParameterValue(value);
}
}
return parameterMap;
}
@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> paramaterNames = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return paramaterNames.hasMoreElements();
}
@Override
public String nextElement() {
String name = paramaterNames.nextElement();
validateAllowedParameterName(name);
return name;
}
};
}
@Override
public String[] getParameterValues(String name) {
validateAllowedParameterName(name);
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value : values) {
validateAllowedParameterValue(value);
}
}
return values;
}
private void validateAllowedHeaderName(String headerNames) {
if (!StrictHttpFirewall.this.allowedHeaderNames.test(headerNames)) {
throw new RequestRejectedException(
"The request was rejected because the header name \"" + headerNames + "\" is not allowed.");
}
}
private void validateAllowedHeaderValue(String value) {
if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the header value \"" + value + "\" is not allowed.");
}
}
private void validateAllowedParameterName(String name) {
if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
throw new RequestRejectedException(
"The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
}
private void validateAllowedParameterValue(String value) {
if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
throw new RequestRejectedException(
"The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
@Override
public void reset() {
}
};
}

View File

@ -62,20 +62,18 @@ public final class Header {
}
@Override
public boolean equals(Object o) {
if (this == o) {
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (o == null || getClass() != o.getClass()) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Header header = (Header) o;
if (!this.headerName.equals(header.headerName)) {
Header other = (Header) obj;
if (!this.headerName.equals(other.headerName)) {
return false;
}
return this.headerValues.equals(header.headerValues);
return this.headerValues.equals(other.headerValues);
}
@Override

View File

@ -68,7 +68,6 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if (this.shouldWriteHeadersEagerly) {
doHeadersBefore(request, response, filterChain);
}

View File

@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@ -76,10 +77,9 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
response.setHeader(CLEAR_SITE_DATA_HEADER, this.headerValue);
}
}
else if (this.logger.isDebugEnabled()) {
this.logger.debug("Not injecting Clear-Site-Data header since it did not match the " + "requestMatcher "
+ this.requestMatcher);
}
this.logger.debug(
LogMessage.format("Not injecting Clear-Site-Data header since it did not match the requestMatcher %s",
this.requestMatcher));
}
private String transformToHeaderValue(Directive... directives) {
@ -97,14 +97,19 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
}
/**
* <p>
* Represents the directive values expected by the {@link ClearSiteDataHeaderWriter}
* </p>
* .
* Represents the directive values expected by the {@link ClearSiteDataHeaderWriter}.
*/
public enum Directive {
CACHE("cache"), COOKIES("cookies"), STORAGE("storage"), EXECUTION_CONTEXTS("executionContexts"), ALL("*");
CACHE("cache"),
COOKIES("cookies"),
STORAGE("storage"),
EXECUTION_CONTEXTS("executionContexts"),
ALL("*");
private final String headerValue;

View File

@ -117,7 +117,7 @@ public final class ContentSecurityPolicyHeaderWriter implements HeaderWriter {
*/
@Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
String headerName = !this.reportOnly ? CONTENT_SECURITY_POLICY_HEADER
String headerName = (!this.reportOnly) ? CONTENT_SECURITY_POLICY_HEADER
: CONTENT_SECURITY_POLICY_REPORT_ONLY_HEADER;
if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.policyDirectives);

View File

@ -174,19 +174,17 @@ public final class HpkpHeaderWriter implements HeaderWriter {
@Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (this.requestMatcher.matches(request)) {
if (!this.pins.isEmpty()) {
String headerName = this.reportOnly ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME;
if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.hpkpHeaderValue);
}
}
if (this.logger.isDebugEnabled()) {
this.logger.debug("Not injecting HPKP header since there aren't any pins");
}
}
else if (this.logger.isDebugEnabled()) {
if (!this.requestMatcher.matches(request)) {
this.logger.debug("Not injecting HPKP header since it wasn't a secure connection");
return;
}
if (this.pins.isEmpty()) {
this.logger.debug("Not injecting HPKP header since there aren't any pins");
return;
}
String headerName = (this.reportOnly) ? HPKP_RO_HEADER_NAME : HPKP_HEADER_NAME;
if (!response.containsHeader(headerName)) {
response.setHeader(headerName, this.hpkpHeaderValue);
}
}
@ -294,9 +292,7 @@ public final class HpkpHeaderWriter implements HeaderWriter {
* @throws IllegalArgumentException if maxAgeInSeconds is negative
*/
public void setMaxAgeInSeconds(long maxAgeInSeconds) {
if (maxAgeInSeconds < 0) {
throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
}
Assert.isTrue(maxAgeInSeconds > 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
this.maxAgeInSeconds = maxAgeInSeconds;
updateHpkpHeaderValue();
}
@ -414,11 +410,11 @@ public final class HpkpHeaderWriter implements HeaderWriter {
public void setReportUri(String reportUri) {
try {
this.reportUri = new URI(reportUri);
updateHpkpHeaderValue();
}
catch (URISyntaxException ex) {
throw new IllegalArgumentException(ex);
}
updateHpkpHeaderValue();
}
private void updateHpkpHeaderValue() {

View File

@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@ -148,14 +149,13 @@ public final class HstsHeaderWriter implements HeaderWriter {
@Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (this.requestMatcher.matches(request)) {
if (!response.containsHeader(HSTS_HEADER_NAME)) {
response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue);
}
if (!this.requestMatcher.matches(request)) {
this.logger.debug(LogMessage.format(
"Not injecting HSTS header since it did not match the requestMatcher %s", this.requestMatcher));
return;
}
else if (this.logger.isDebugEnabled()) {
this.logger.debug(
"Not injecting HSTS header since it did not match the requestMatcher " + this.requestMatcher);
if (!response.containsHeader(HSTS_HEADER_NAME)) {
response.setHeader(HSTS_HEADER_NAME, this.hstsHeaderValue);
}
}
@ -188,9 +188,7 @@ public final class HstsHeaderWriter implements HeaderWriter {
* @throws IllegalArgumentException if maxAgeInSeconds is negative
*/
public void setMaxAgeInSeconds(long maxAgeInSeconds) {
if (maxAgeInSeconds < 0) {
throw new IllegalArgumentException("maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
}
Assert.isTrue(maxAgeInSeconds >= 0, () -> "maxAgeInSeconds must be non-negative. Got " + maxAgeInSeconds);
this.maxAgeInSeconds = maxAgeInSeconds;
updateHstsHeaderValue();
}

View File

@ -100,10 +100,21 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter {
public enum ReferrerPolicy {
NO_REFERRER("no-referrer"), NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"), SAME_ORIGIN(
"same-origin"), ORIGIN("origin"), STRICT_ORIGIN("strict-origin"), ORIGIN_WHEN_CROSS_ORIGIN(
"origin-when-cross-origin"), STRICT_ORIGIN_WHEN_CROSS_ORIGIN(
"strict-origin-when-cross-origin"), UNSAFE_URL("unsafe-url");
NO_REFERRER("no-referrer"),
NO_REFERRER_WHEN_DOWNGRADE("no-referrer-when-downgrade"),
SAME_ORIGIN("same-origin"),
ORIGIN("origin"),
STRICT_ORIGIN("strict-origin"),
ORIGIN_WHEN_CROSS_ORIGIN("origin-when-cross-origin"),
STRICT_ORIGIN_WHEN_CROSS_ORIGIN("strict-origin-when-cross-origin"),
UNSAFE_URL("unsafe-url");
private static final Map<String, ReferrerPolicy> REFERRER_POLICIES;
@ -115,7 +126,7 @@ public class ReferrerPolicyHeaderWriter implements HeaderWriter {
REFERRER_POLICIES = Collections.unmodifiableMap(referrerPolicies);
}
private String policy;
private final String policy;
ReferrerPolicy(String policy) {
this.policy = policy;

View File

@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -52,15 +53,11 @@ public abstract class AbstractRequestParameterAllowFromStrategy implements Allow
@Override
public String getAllowFromValue(HttpServletRequest request) {
String allowFromOrigin = request.getParameter(this.allowFromParameterName);
if (this.log.isDebugEnabled()) {
this.log.debug("Supplied origin '" + allowFromOrigin + "'");
}
this.log.debug(LogMessage.format("Supplied origin '%s'", allowFromOrigin));
if (StringUtils.hasText(allowFromOrigin) && allowed(allowFromOrigin)) {
return allowFromOrigin;
}
else {
return "DENY";
}
return "DENY";
}
/**

View File

@ -55,10 +55,9 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
*/
public XFrameOptionsHeaderWriter(XFrameOptionsMode frameOptionsMode) {
Assert.notNull(frameOptionsMode, "frameOptionsMode cannot be null");
if (XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode)) {
throw new IllegalArgumentException(
"ALLOW_FROM requires an AllowFromStrategy. Please use FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead");
}
Assert.isTrue(!XFrameOptionsMode.ALLOW_FROM.equals(frameOptionsMode),
"ALLOW_FROM requires an AllowFromStrategy. Please use "
+ "FrameOptionsHeaderWriter(AllowFromStrategy allowFromStrategy) instead");
this.frameOptionsMode = frameOptionsMode;
this.allowFromStrategy = null;
}
@ -113,7 +112,10 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
*/
public enum XFrameOptionsMode {
DENY("DENY"), SAMEORIGIN("SAMEORIGIN"),
DENY("DENY"),
SAMEORIGIN("SAMEORIGIN"),
/**
* @deprecated ALLOW-FROM is an obsolete directive that no longer works in modern
* browsers. Instead use Content-Security-Policy with the <a href=
@ -123,7 +125,7 @@ public final class XFrameOptionsHeaderWriter implements HeaderWriter {
@Deprecated
ALLOW_FROM("ALLOW-FROM");
private String mode;
private final String mode;
XFrameOptionsMode(String mode) {
this.mode = mode;

View File

@ -29,6 +29,9 @@ import org.springframework.util.Assert;
*/
public final class SecurityHeaders {
private SecurityHeaders() {
}
/**
* Sets the provided value as a Bearer token in a header with the name of
* {@link HttpHeaders#AUTHORIZATION}
@ -40,7 +43,4 @@ public final class SecurityHeaders {
return (headers) -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + bearerTokenValue);
}
private SecurityHeaders() {
}
}

Some files were not shown because too many files have changed in this diff Show More