diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 9a489c6f96..362168f109 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web.csrf; import java.io.IOException; +import java.security.MessageDigest; import java.util.Arrays; import java.util.HashSet; @@ -28,6 +30,9 @@ import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.util.UrlUtils; @@ -35,8 +40,6 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; -import static java.lang.Boolean.TRUE; - /** *

* Applies @@ -58,6 +61,7 @@ import static java.lang.Boolean.TRUE; * @since 3.2 */ public final class CsrfFilter extends OncePerRequestFilter { + /** * The default {@link RequestMatcher} that indicates if CSRF protection is required or * not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other @@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter { public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher(); /** - * The attribute name to use when marking a given request as one that should not be filtered. + * The attribute name to use when marking a given request as one that should not be + * filtered. * - * To use, set the attribute on your {@link HttpServletRequest}: - *

+	 * To use, set the attribute on your {@link HttpServletRequest}: 
 	 * 	CsrfFilter.skipRequest(request);
 	 * 
*/ private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName(); private final Log logger = LogFactory.getLog(getClass()); + private final CsrfTokenRepository tokenRepository; + private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; + private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { @@ -87,62 +94,46 @@ public final class CsrfFilter extends OncePerRequestFilter { @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER)); + return Boolean.TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER)); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet - * .http.HttpServletRequest, javax.servlet.http.HttpServletResponse, - * javax.servlet.FilterChain) - */ @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { request.setAttribute(HttpServletResponse.class.getName(), response); - CsrfToken csrfToken = this.tokenRepository.loadToken(request); - final boolean missingToken = csrfToken == null; + boolean missingToken = (csrfToken == null); if (missingToken) { csrfToken = this.tokenRepository.generateToken(request); this.tokenRepository.saveToken(csrfToken, request, response); } request.setAttribute(CsrfToken.class.getName(), csrfToken); request.setAttribute(csrfToken.getParameterName(), csrfToken); - if (!this.requireCsrfProtectionMatcher.matches(request)) { + if (this.logger.isTraceEnabled()) { + this.logger.trace("Did not protect against CSRF since request did not match " + + this.requireCsrfProtectionMatcher); + } filterChain.doFilter(request, response); return; } - String actualToken = request.getHeader(csrfToken.getHeaderName()); if (actualToken == null) { actualToken = request.getParameter(csrfToken.getParameterName()); } - if (!csrfToken.getToken().equals(actualToken)) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Invalid CSRF token found for " - + UrlUtils.buildFullRequestUrl(request)); - } - if (missingToken) { - this.accessDeniedHandler.handle(request, response, - new MissingCsrfTokenException(actualToken)); - } - else { - this.accessDeniedHandler.handle(request, response, - new InvalidCsrfTokenException(csrfToken, actualToken)); - } + if (!equalsConstantTime(csrfToken.getToken(), actualToken)) { + this.logger.debug( + LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request))); + AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken) + : new MissingCsrfTokenException(actualToken); + this.accessDeniedHandler.handle(request, response, exception); return; } - filterChain.doFilter(request, response); } public static void skipRequest(HttpServletRequest request) { - request.setAttribute(SHOULD_NOT_FILTER, TRUE); + request.setAttribute(SHOULD_NOT_FILTER, Boolean.TRUE); } /** @@ -154,14 +145,11 @@ public final class CsrfFilter extends OncePerRequestFilter { * The default is to apply CSRF protection for any HTTP method other than GET, HEAD, * TRACE, OPTIONS. *

- * * @param requireCsrfProtectionMatcher the {@link RequestMatcher} used to determine if * CSRF protection should be applied. */ - public void setRequireCsrfProtectionMatcher( - RequestMatcher requireCsrfProtectionMatcher) { - Assert.notNull(requireCsrfProtectionMatcher, - "requireCsrfProtectionMatcher cannot be null"); + public void setRequireCsrfProtectionMatcher(RequestMatcher requireCsrfProtectionMatcher) { + Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; } @@ -172,7 +160,6 @@ public final class CsrfFilter extends OncePerRequestFilter { *

* The default is to use AccessDeniedHandlerImpl with no arguments. *

- * * @param accessDeniedHandler the {@link AccessDeniedHandler} to use */ public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { @@ -180,20 +167,38 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } - private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { - private final HashSet allowedMethods = new HashSet<>( - Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; + } + + private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { + + private final HashSet allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax. - * servlet.http.HttpServletRequest) - */ @Override public boolean matches(HttpServletRequest request) { return !this.allowedMethods.contains(request.getMethod()); } + + @Override + public String toString() { + return "CsrfNotRequired " + this.allowedMethods; + } + } + } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 35cfe2a65a..a2699018b3 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -16,26 +16,28 @@ package org.springframework.security.web.server.csrf; +import java.security.MessageDigest; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.multipart.FormFieldPart; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - -import static java.lang.Boolean.TRUE; /** *

@@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE; * @since 5.0 */ public class CsrfWebFilter implements WebFilter { + public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); /** - * The attribute name to use when marking a given request as one that should not be filtered. + * The attribute name to use when marking a given request as one that should not be + * filtered. * - * To use, set the attribute on your {@link ServerWebExchange}: - *

+	 * To use, set the attribute on your {@link ServerWebExchange}: 
 	 * 	CsrfWebFilter.skipExchange(exchange);
 	 * 
*/ @@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter { private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); - private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); + private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler( + HttpStatus.FORBIDDEN); private boolean isTokenFromMultipartDataEnabled; - public void setAccessDeniedHandler( - ServerAccessDeniedHandler accessDeniedHandler) { + public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); this.accessDeniedHandler = accessDeniedHandler; } - public void setCsrfTokenRepository( - ServerCsrfTokenRepository csrfTokenRepository) { + public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); this.csrfTokenRepository = csrfTokenRepository; } - public void setRequireCsrfProtectionMatcher( - ServerWebExchangeMatcher requireCsrfProtectionMatcher) { + public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) { Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; } /** - * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart - * data requests. - * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false + * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token + * from the body of multipart data requests. + * @param tokenFromMultipartDataEnabled true if should read from multipart form body, + * else false. Default is false */ public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; @@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { + if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { return chain.filter(exchange).then(Mono.empty()); } - - return this.requireCsrfProtectionMatcher.matches(exchange) - .filter( matchResult -> matchResult.isMatch()) - .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) - .flatMap(m -> validateToken(exchange)) - .flatMap(m -> continueFilterChain(exchange, chain)) - .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) - .onErrorResume(CsrfException.class, e -> this.accessDeniedHandler - .handle(exchange, e)); + return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch) + .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) + .flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain)) + .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) + .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex)); } public static void skipExchange(ServerWebExchange exchange) { - exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE); + exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE); } private Mono validateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.loadToken(exchange) - .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) - .filterWhen(expected -> containsValidCsrfToken(exchange, expected)) - .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))) - .then(); + .switchIfEmpty( + Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) + .filterWhen((expected) -> containsValidCsrfToken(exchange, expected)) + .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then(); } private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { - return exchange.getFormData() - .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) - .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) - .switchIfEmpty(tokenFromMultipartData(exchange, expected)) - .map(actual -> actual.equals(expected.getToken())); + return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) + .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) + .switchIfEmpty(tokenFromMultipartData(exchange, expected)) + .map((actual) -> equalsConstantTime(actual, expected.getToken())); } private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { @@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter { if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) { return Mono.empty(); } - return exchange.getMultipartData() - .map(d -> d.getFirst(expected.getParameterName())) - .cast(FormFieldPart.class) - .map(FormFieldPart::value); + return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class) + .map(FormFieldPart::value); } private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { - return Mono.defer(() ->{ + return Mono.defer(() -> { Mono csrfToken = csrfToken(exchange); exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); return chain.filter(exchange); @@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter { } private Mono csrfToken(ServerWebExchange exchange) { - return this.csrfTokenRepository.loadToken(exchange) - .switchIfEmpty(generateToken(exchange)); + return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange)); + } + + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; } private Mono generateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.generateToken(exchange) - .delayUntil(token -> this.csrfTokenRepository.saveToken(exchange, token)); + .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)); } private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { + private static final Set ALLOWED_METHODS = new HashSet<>( - Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); + Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); @Override public Mono matches(ServerWebExchange exchange) { - return Mono.just(exchange.getRequest()) - .flatMap(r -> Mono.justOrEmpty(r.getMethod())) - .filter(m -> ALLOWED_METHODS.contains(m)) - .flatMap(m -> MatchResult.notMatch()) - .switchIfEmpty(MatchResult.match()); + return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod())) + .filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch()) + .switchIfEmpty(MatchResult.match()); } + } + }