Default to Xor CSRF tokens in CsrfWebFilter

Closes gh-11960
This commit is contained in:
Steve Riesenberg 2022-10-12 20:48:29 -05:00
parent 2a2051cd7b
commit 2407d07890
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
3 changed files with 33 additions and 18 deletions

View File

@ -27,6 +27,8 @@ import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.mock.http.server.reactive.MockServerHttpRequest
import org.springframework.mock.web.server.MockServerWebExchange
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestContext
import org.springframework.security.config.test.SpringTestContextExtension
@ -39,6 +41,7 @@ import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
import org.springframework.security.web.server.csrf.XorServerCsrfTokenRequestAttributeHandler
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.bind.annotation.PostMapping
@ -278,14 +281,23 @@ class ServerCsrfDslTests {
MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
} returns Mono.just(this.token)
val csrfToken = createXorCsrfToken()
this.client.post()
.uri("/")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(fromMultipartData(this.token.parameterName, this.token.token))
.body(fromMultipartData(csrfToken.parameterName, csrfToken.token))
.exchange()
.expectStatus().isOk
}
private fun createXorCsrfToken(): CsrfToken {
val handler = XorServerCsrfTokenRequestAttributeHandler()
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"))
handler.handle(exchange, Mono.just(this.token))
val deferredCsrfToken: Mono<CsrfToken>? = exchange.getAttribute(CsrfToken::class.java.name)
return deferredCsrfToken?.block()!!
}
@Configuration
@EnableWebFluxSecurity
@EnableWebFlux

View File

@ -83,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
HttpStatus.FORBIDDEN);
private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
private ServerCsrfTokenRequestHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");

View File

@ -125,9 +125,10 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(this.token.getParameterName() + "=" + this.token.getToken()));
.body(csrfToken.getParameterName() + "=" + csrfToken.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result).verifyComplete();
chainResult.assertWasSubscribed();
@ -151,8 +152,9 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result).verifyComplete();
chainResult.assertWasSubscribed();
@ -181,30 +183,22 @@ public class CsrfWebFilterTests {
}
@Test
public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() {
public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
given(this.chain.filter(any())).willReturn(chainResult.mono());
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
this.csrfFilter.setRequestHandler(requestHandler);
StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete();
chainResult.assertWasSubscribed();
Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName());
assertThat(csrfTokenAttribute).isNotNull();
StepVerifier.create(csrfTokenAttribute)
.consumeNextWith((csrfToken) -> this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken())))
.verifyComplete();
CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete();
chainResult.assertWasSubscribed();
}
@Test
public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() {
public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
@ -305,6 +299,7 @@ public class CsrfWebFilterTests {
}
// gh-9561
@Test
public void doFilterWhenTokenIsNullThenNoNullPointer() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
@ -318,8 +313,8 @@ public class CsrfWebFilterTests {
.bodyValue(this.token.getParameterName() + "=" + this.token.getToken()).exchange().expectStatus()
.isForbidden();
}
// gh-9113
@Test
public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() {
PublisherProbe<CsrfToken> chainResult = PublisherProbe.empty();
@ -334,6 +329,14 @@ public class CsrfWebFilterTests {
assertThat(chainResult.subscribeCount()).isEqualTo(1);
}
private CsrfToken createXorCsrfToken() {
ServerCsrfTokenRequestHandler handler = new XorServerCsrfTokenRequestAttributeHandler();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
handler.handle(exchange, Mono.just(this.token));
Mono<CsrfToken> csrfToken = exchange.getAttribute(CsrfToken.class.getName());
return csrfToken.block();
}
@RestController
static class OkController {