CsrfWebFilter supports multipart/form-data

Fixes gh-7576
This commit is contained in:
Rob Winch 2019-10-28 14:05:42 -05:00
parent 387f765595
commit 635f7e1edd
5 changed files with 148 additions and 11 deletions

View File

@ -2731,6 +2731,19 @@ public class ServerHttpSecurity {
return this; return this;
} }
/**
* Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
* data requests.
*
* @param enabled true if should read from multipart form body, else false. Default is false
* @return the {@link CsrfSpec} for additional configuration
*/
public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
this.filter.setTokenFromMultipartDataEnabled(enabled);
return this;
}
/** /**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity} * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring * @return the {@link ServerHttpSecurity} to continue configuring

View File

@ -210,6 +210,7 @@ dependencyManagement {
dependency 'org.slf4j:slf4j-nop:1.7.28' dependency 'org.slf4j:slf4j-nop:1.7.28'
dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111' dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111'
dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE' dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE'
dependency 'org.synchronoss.cloud:nio-multipart-parser:1.1.0'
dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE' dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE'
dependency 'org.unbescape:unbescape:1.1.5.RELEASE' dependency 'org.unbescape:unbescape:1.1.5.RELEASE'
dependency 'org.w3c.css:sac:1.3' dependency 'org.w3c.css:sac:1.3'

View File

@ -25,6 +25,7 @@ dependencies {
testCompile 'org.codehaus.groovy:groovy-all' testCompile 'org.codehaus.groovy:groovy-all'
testCompile 'org.skyscreamer:jsonassert' testCompile 'org.skyscreamer:jsonassert'
testCompile 'org.springframework:spring-webflux' testCompile 'org.springframework:spring-webflux'
testCompile 'org.synchronoss.cloud:nio-multipart-parser'
testCompile powerMock2Dependencies testCompile powerMock2Dependencies
testCompile spockDependencies testCompile spockDependencies

View File

@ -16,14 +16,12 @@
package org.springframework.security.web.server.csrf; package org.springframework.security.web.server.csrf;
import java.util.Arrays; import org.springframework.http.HttpHeaders;
import java.util.HashSet;
import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
@ -31,6 +29,11 @@ import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import static java.lang.Boolean.TRUE; import static java.lang.Boolean.TRUE;
@ -78,6 +81,8 @@ public class CsrfWebFilter implements WebFilter {
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
private boolean isTokenFromMultipartDataEnabled;
public void setAccessDeniedHandler( public void setAccessDeniedHandler(
ServerAccessDeniedHandler accessDeniedHandler) { ServerAccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
@ -96,6 +101,15 @@ public class CsrfWebFilter implements WebFilter {
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
} }
/**
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
* data requests.
* @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
*/
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
}
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
@ -128,9 +142,26 @@ public class CsrfWebFilter implements WebFilter {
return exchange.getFormData() return exchange.getFormData()
.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
.map(actual -> actual.equals(expected.getToken())); .map(actual -> actual.equals(expected.getToken()));
} }
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
if (!this.isTokenFromMultipartDataEnabled) {
return Mono.empty();
}
ServerHttpRequest request = exchange.getRequest();
HttpHeaders headers = request.getHeaders();
MediaType contentType = headers.getContentType();
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
return Mono.empty();
}
return exchange.getMultipartData()
.map(d -> d.getFirst(expected.getParameterName()))
.cast(FormFieldPart.class)
.map(FormFieldPart::value);
}
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return Mono.defer(() ->{ return Mono.defer(() ->{
Mono<CsrfToken> csrfToken = csrfToken(exchange); Mono<CsrfToken> csrfToken = csrfToken(exchange);

View File

@ -20,17 +20,20 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -38,6 +41,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.mock.web.server.MockServerWebExchange.from; import static org.springframework.mock.web.server.MockServerWebExchange.from;
import static org.springframework.web.reactive.function.BodyInserters.fromMultipartData;
/** /**
* @author Rob Winch * @author Rob Winch
@ -57,7 +61,7 @@ public class CsrfWebFilterTests {
private MockServerWebExchange get = from( private MockServerWebExchange get = from(
MockServerHttpRequest.get("/")); MockServerHttpRequest.get("/"));
private MockServerWebExchange post = from( private ServerWebExchange post = from(
MockServerHttpRequest.post("/")); MockServerHttpRequest.post("/"));
@Test @Test
@ -193,4 +197,91 @@ public class CsrfWebFilterTests {
verifyZeroInteractions(matcher); verifyZeroInteractions(matcher);
} }
@Test
public void filterWhenMultipartFormDataAndNotEnabledThenDenied() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
WebTestClient client = WebTestClient.bindToController(new OkController())
.webFilter(this.csrfFilter)
.build();
client.post()
.uri("/")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
.exchange()
.expectStatus().isForbidden();
}
@Test
public void filterWhenMultipartFormDataAndEnabledThenGranted() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
WebTestClient client = WebTestClient.bindToController(new OkController())
.webFilter(this.csrfFilter)
.build();
client.post()
.uri("/")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
.exchange()
.expectStatus().is2xxSuccessful();
}
@Test
public void filterWhenFormDataAndEnabledThenGranted() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
WebTestClient client = WebTestClient.bindToController(new OkController())
.webFilter(this.csrfFilter)
.build();
client.post()
.uri("/")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
.exchange()
.expectStatus().is2xxSuccessful();
}
@Test
public void filterWhenMultipartMixedAndEnabledThenNotRead() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
WebTestClient client = WebTestClient.bindToController(new OkController())
.webFilter(this.csrfFilter)
.build();
client.post()
.uri("/")
.contentType(MediaType.MULTIPART_MIXED)
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
.exchange()
.expectStatus().isForbidden();
}
@RestController
static class OkController {
@RequestMapping("/**")
String ok() {
return "ok";
}
}
} }