From e6d6b397677598b62e5270309d3f0a0711d8e692 Mon Sep 17 00:00:00 2001
From: Rob Winch
* Applies
@@ -58,6 +61,7 @@ import static java.lang.Boolean.TRUE;
* @since 3.2
*/
public final class CsrfFilter extends OncePerRequestFilter {
+
/**
* The default {@link RequestMatcher} that indicates if CSRF protection is required or
* not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other
@@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
/**
- * The attribute name to use when marking a given request as one that should not be filtered.
+ * The attribute name to use when marking a given request as one that should not be
+ * filtered.
*
- * To use, set the attribute on your {@link HttpServletRequest}:
- *
+ * To use, set the attribute on your {@link HttpServletRequest}:
* CsrfFilter.skipRequest(request);
*
*/
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
private final Log logger = LogFactory.getLog(getClass());
+
private final CsrfTokenRepository tokenRepository;
+
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
+
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
@@ -87,62 +94,46 @@ public final class CsrfFilter extends OncePerRequestFilter {
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
- return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
+ return Boolean.TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
}
- /*
- * (non-Javadoc)
- *
- * @see
- * org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet
- * .http.HttpServletRequest, javax.servlet.http.HttpServletResponse,
- * javax.servlet.FilterChain)
- */
@Override
- protected void doFilterInternal(HttpServletRequest request,
- HttpServletResponse response, FilterChain filterChain)
- throws ServletException, IOException {
+ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
+ throws ServletException, IOException {
request.setAttribute(HttpServletResponse.class.getName(), response);
-
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
- final boolean missingToken = csrfToken == null;
+ boolean missingToken = (csrfToken == null);
if (missingToken) {
csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
-
if (!this.requireCsrfProtectionMatcher.matches(request)) {
+ if (this.logger.isTraceEnabled()) {
+ this.logger.trace("Did not protect against CSRF since request did not match "
+ + this.requireCsrfProtectionMatcher);
+ }
filterChain.doFilter(request, response);
return;
}
-
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
- if (!csrfToken.getToken().equals(actualToken)) {
- if (this.logger.isDebugEnabled()) {
- this.logger.debug("Invalid CSRF token found for "
- + UrlUtils.buildFullRequestUrl(request));
- }
- if (missingToken) {
- this.accessDeniedHandler.handle(request, response,
- new MissingCsrfTokenException(actualToken));
- }
- else {
- this.accessDeniedHandler.handle(request, response,
- new InvalidCsrfTokenException(csrfToken, actualToken));
- }
+ if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
+ this.logger.debug(
+ LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
+ AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
+ : new MissingCsrfTokenException(actualToken);
+ this.accessDeniedHandler.handle(request, response, exception);
return;
}
-
filterChain.doFilter(request, response);
}
public static void skipRequest(HttpServletRequest request) {
- request.setAttribute(SHOULD_NOT_FILTER, TRUE);
+ request.setAttribute(SHOULD_NOT_FILTER, Boolean.TRUE);
}
/**
@@ -154,14 +145,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
* The default is to apply CSRF protection for any HTTP method other than GET, HEAD,
* TRACE, OPTIONS.
*
* The default is to use AccessDeniedHandlerImpl with no arguments. *
- * * @param accessDeniedHandler the {@link AccessDeniedHandler} to use */ public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { @@ -180,20 +167,38 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { - private final HashSet@@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE; * @since 5.0 */ public class CsrfWebFilter implements WebFilter { + public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); /** - * The attribute name to use when marking a given request as one that should not be filtered. + * The attribute name to use when marking a given request as one that should not be + * filtered. * - * To use, set the attribute on your {@link ServerWebExchange}: - *
+ * To use, set the attribute on your {@link ServerWebExchange}:* CsrfWebFilter.skipExchange(exchange); **/ @@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter { private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); - private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); + private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler( + HttpStatus.FORBIDDEN); private boolean isTokenFromMultipartDataEnabled; - public void setAccessDeniedHandler( - ServerAccessDeniedHandler accessDeniedHandler) { + public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); this.accessDeniedHandler = accessDeniedHandler; } - public void setCsrfTokenRepository( - ServerCsrfTokenRepository csrfTokenRepository) { + public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); this.csrfTokenRepository = csrfTokenRepository; } - public void setRequireCsrfProtectionMatcher( - ServerWebExchangeMatcher requireCsrfProtectionMatcher) { + public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) { Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; } /** - * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart - * data requests. - * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false + * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token + * from the body of multipart data requests. + * @param tokenFromMultipartDataEnabled true if should read from multipart form body, + * else false. Default is false */ public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; @@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter { @Override public Monofilter(ServerWebExchange exchange, WebFilterChain chain) { - if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { + if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { return chain.filter(exchange).then(Mono.empty()); } - - return this.requireCsrfProtectionMatcher.matches(exchange) - .filter( matchResult -> matchResult.isMatch()) - .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) - .flatMap(m -> validateToken(exchange)) - .flatMap(m -> continueFilterChain(exchange, chain)) - .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) - .onErrorResume(CsrfException.class, e -> this.accessDeniedHandler - .handle(exchange, e)); + return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch) + .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) + .flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain)) + .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) + .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex)); } public static void skipExchange(ServerWebExchange exchange) { - exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE); + exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE); } private Mono validateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.loadToken(exchange) - .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) - .filterWhen(expected -> containsValidCsrfToken(exchange, expected)) - .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))) - .then(); + .switchIfEmpty( + Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) + .filterWhen((expected) -> containsValidCsrfToken(exchange, expected)) + .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then(); } private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { - return exchange.getFormData() - .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) - .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) - .switchIfEmpty(tokenFromMultipartData(exchange, expected)) - .map(actual -> actual.equals(expected.getToken())); + return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) + .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) + .switchIfEmpty(tokenFromMultipartData(exchange, expected)) + .map((actual) -> equalsConstantTime(actual, expected.getToken())); } private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { @@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter { if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) { return Mono.empty(); } - return exchange.getMultipartData() - .map(d -> d.getFirst(expected.getParameterName())) - .cast(FormFieldPart.class) - .map(FormFieldPart::value); + return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class) + .map(FormFieldPart::value); } private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { - return Mono.defer(() ->{ + return Mono.defer(() -> { Mono csrfToken = csrfToken(exchange); exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); return chain.filter(exchange); @@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter { } private Mono csrfToken(ServerWebExchange exchange) { - return this.csrfTokenRepository.loadToken(exchange) - .switchIfEmpty(generateToken(exchange)); + return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange)); + } + + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; } private Mono generateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.generateToken(exchange) - .delayUntil(token -> this.csrfTokenRepository.saveToken(exchange, token)); + .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)); } private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { + private static final Set ALLOWED_METHODS = new HashSet<>( - Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); + Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); @Override public Mono matches(ServerWebExchange exchange) { - return Mono.just(exchange.getRequest()) - .flatMap(r -> Mono.justOrEmpty(r.getMethod())) - .filter(m -> ALLOWED_METHODS.contains(m)) - .flatMap(m -> MatchResult.notMatch()) - .switchIfEmpty(MatchResult.match()); + return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod())) + .filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch()) + .switchIfEmpty(MatchResult.match()); } + } + }