parent
387f765595
commit
635f7e1edd
|
@ -2731,6 +2731,19 @@ public class ServerHttpSecurity {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
|
||||
* data requests.
|
||||
*
|
||||
* @param enabled true if should read from multipart form body, else false. Default is false
|
||||
* @return the {@link CsrfSpec} for additional configuration
|
||||
*/
|
||||
public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
|
||||
this.filter.setTokenFromMultipartDataEnabled(enabled);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
|
||||
* @return the {@link ServerHttpSecurity} to continue configuring
|
||||
|
|
|
@ -210,6 +210,7 @@ dependencyManagement {
|
|||
dependency 'org.slf4j:slf4j-nop:1.7.28'
|
||||
dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111'
|
||||
dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE'
|
||||
dependency 'org.synchronoss.cloud:nio-multipart-parser:1.1.0'
|
||||
dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE'
|
||||
dependency 'org.unbescape:unbescape:1.1.5.RELEASE'
|
||||
dependency 'org.w3c.css:sac:1.3'
|
||||
|
|
|
@ -25,6 +25,7 @@ dependencies {
|
|||
testCompile 'org.codehaus.groovy:groovy-all'
|
||||
testCompile 'org.skyscreamer:jsonassert'
|
||||
testCompile 'org.springframework:spring-webflux'
|
||||
testCompile 'org.synchronoss.cloud:nio-multipart-parser'
|
||||
testCompile powerMock2Dependencies
|
||||
testCompile spockDependencies
|
||||
|
||||
|
|
|
@ -16,14 +16,12 @@
|
|||
|
||||
package org.springframework.security.web.server.csrf;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.codec.multipart.FormFieldPart;
|
||||
import org.springframework.http.server.reactive.ServerHttpRequest;
|
||||
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
||||
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
||||
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
||||
|
@ -31,6 +29,11 @@ import org.springframework.util.Assert;
|
|||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.lang.Boolean.TRUE;
|
||||
|
||||
|
@ -78,6 +81,8 @@ public class CsrfWebFilter implements WebFilter {
|
|||
|
||||
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
|
||||
|
||||
private boolean isTokenFromMultipartDataEnabled;
|
||||
|
||||
public void setAccessDeniedHandler(
|
||||
ServerAccessDeniedHandler accessDeniedHandler) {
|
||||
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
|
||||
|
@ -96,6 +101,15 @@ public class CsrfWebFilter implements WebFilter {
|
|||
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
|
||||
* data requests.
|
||||
* @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
|
||||
*/
|
||||
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
|
||||
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
|
||||
|
@ -128,9 +142,26 @@ public class CsrfWebFilter implements WebFilter {
|
|||
return exchange.getFormData()
|
||||
.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
||||
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
||||
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
||||
.map(actual -> actual.equals(expected.getToken()));
|
||||
}
|
||||
|
||||
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
||||
if (!this.isTokenFromMultipartDataEnabled) {
|
||||
return Mono.empty();
|
||||
}
|
||||
ServerHttpRequest request = exchange.getRequest();
|
||||
HttpHeaders headers = request.getHeaders();
|
||||
MediaType contentType = headers.getContentType();
|
||||
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
|
||||
return Mono.empty();
|
||||
}
|
||||
return exchange.getMultipartData()
|
||||
.map(d -> d.getFirst(expected.getParameterName()))
|
||||
.cast(FormFieldPart.class)
|
||||
.map(FormFieldPart::value);
|
||||
}
|
||||
|
||||
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
return Mono.defer(() ->{
|
||||
Mono<CsrfToken> csrfToken = csrfToken(exchange);
|
||||
|
|
|
@ -20,17 +20,20 @@ import org.junit.Test;
|
|||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.test.publisher.PublisherProbe;
|
||||
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
||||
import org.springframework.test.web.reactive.server.WebTestClient;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
import org.springframework.web.server.WebSession;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.test.publisher.PublisherProbe;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
|
@ -38,6 +41,7 @@ import static org.mockito.Mockito.mock;
|
|||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.mock.web.server.MockServerWebExchange.from;
|
||||
import static org.springframework.web.reactive.function.BodyInserters.fromMultipartData;
|
||||
|
||||
/**
|
||||
* @author Rob Winch
|
||||
|
@ -57,7 +61,7 @@ public class CsrfWebFilterTests {
|
|||
private MockServerWebExchange get = from(
|
||||
MockServerHttpRequest.get("/"));
|
||||
|
||||
private MockServerWebExchange post = from(
|
||||
private ServerWebExchange post = from(
|
||||
MockServerHttpRequest.post("/"));
|
||||
|
||||
@Test
|
||||
|
@ -193,4 +197,91 @@ public class CsrfWebFilterTests {
|
|||
|
||||
verifyZeroInteractions(matcher);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenMultipartFormDataAndNotEnabledThenDenied() {
|
||||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
|
||||
WebTestClient client = WebTestClient.bindToController(new OkController())
|
||||
.webFilter(this.csrfFilter)
|
||||
.build();
|
||||
|
||||
client.post()
|
||||
.uri("/")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA)
|
||||
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
|
||||
.exchange()
|
||||
.expectStatus().isForbidden();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenMultipartFormDataAndEnabledThenGranted() {
|
||||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
when(this.repository.generateToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
|
||||
WebTestClient client = WebTestClient.bindToController(new OkController())
|
||||
.webFilter(this.csrfFilter)
|
||||
.build();
|
||||
|
||||
client.post()
|
||||
.uri("/")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA)
|
||||
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
|
||||
.exchange()
|
||||
.expectStatus().is2xxSuccessful();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenFormDataAndEnabledThenGranted() {
|
||||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
when(this.repository.generateToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
|
||||
WebTestClient client = WebTestClient.bindToController(new OkController())
|
||||
.webFilter(this.csrfFilter)
|
||||
.build();
|
||||
|
||||
client.post()
|
||||
.uri("/")
|
||||
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
|
||||
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
|
||||
.exchange()
|
||||
.expectStatus().is2xxSuccessful();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenMultipartMixedAndEnabledThenNotRead() {
|
||||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
|
||||
WebTestClient client = WebTestClient.bindToController(new OkController())
|
||||
.webFilter(this.csrfFilter)
|
||||
.build();
|
||||
|
||||
client.post()
|
||||
.uri("/")
|
||||
.contentType(MediaType.MULTIPART_MIXED)
|
||||
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
|
||||
.exchange()
|
||||
.expectStatus().isForbidden();
|
||||
}
|
||||
|
||||
@RestController
|
||||
static class OkController {
|
||||
@RequestMapping("/**")
|
||||
String ok() {
|
||||
return "ok";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue