From 8da2c7f657a4884718489572122b02af22ac61b7 Mon Sep 17 00:00:00 2001
From: Rob Winch
Date: Sat, 28 Oct 2017 20:27:57 -0500
Subject: [PATCH] Add WebFlux CSRF Protection
Fixes gh-4734
---
.../web/server/SecurityWebFiltersOrder.java | 4 +
.../config/web/server/ServerHttpSecurity.java | 62 ++++++
.../reactive/EnableWebFluxSecurityTests.java | 2 +
.../server/AuthorizeExchangeBuilderTests.java | 2 +
.../server/SecurityMockServerConfigurers.java | 33 ++++
.../SecurityMockServerConfigurersTests.java | 38 +++-
.../AuthenticationWebFilter.java | 1 +
.../web/server/csrf/CsrfException.java | 33 ++++
.../security/web/server/csrf/CsrfToken.java | 48 +++++
.../web/server/csrf/CsrfWebFilter.java | 140 +++++++++++++
.../web/server/csrf/DefaultCsrfToken.java | 77 ++++++++
.../csrf/ServerCsrfTokenRepository.java | 58 ++++++
.../WebSessionServerCsrfTokenRepository.java | 122 ++++++++++++
.../ui/LoginPageGeneratingWebFilter.java | 38 ++--
.../web/server/csrf/CsrfWebFilterTests.java | 185 ++++++++++++++++++
...SessionServerCsrfTokenRepositoryTests.java | 112 +++++++++++
16 files changed, 943 insertions(+), 12 deletions(-)
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java
create mode 100644 web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java
create mode 100644 web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
create mode 100644 web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java
diff --git a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java
index c13a4f68bb..54ada6f56a 100644
--- a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java
+++ b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java
@@ -23,6 +23,10 @@ package org.springframework.security.config.web.server;
public enum SecurityWebFiltersOrder {
FIRST(Integer.MIN_VALUE),
HTTP_HEADERS_WRITER,
+ /**
+ * {@link org.springframework.security.web.server.csrf.CsrfWebFilter}
+ */
+ CSRF,
/**
* Instance of AuthenticationWebFilter
*/
diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
index 389b75d79f..e9d5a6a9a0 100644
--- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
+++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
@@ -44,11 +44,14 @@ import org.springframework.security.web.server.authorization.AuthorizationContex
import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
+import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.context.ReactorContextWebFilter;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter;
@@ -90,6 +93,8 @@ public class ServerHttpSecurity {
private HeaderBuilder headers;
+ private CsrfBuilder csrf = new CsrfBuilder();
+
private HttpBasicBuilder httpBasic;
private FormLoginBuilder formLogin;
@@ -139,6 +144,13 @@ public class ServerHttpSecurity {
return this;
}
+ public CsrfBuilder csrf() {
+ if(this.csrf == null) {
+ this.csrf = new CsrfBuilder();
+ }
+ return this.csrf;
+ }
+
public HttpBasicBuilder httpBasic() {
if(this.httpBasic == null) {
this.httpBasic = new HttpBasicBuilder();
@@ -191,6 +203,9 @@ public class ServerHttpSecurity {
if(securityContextRepositoryWebFilter != null) {
this.webFilters.add(securityContextRepositoryWebFilter);
}
+ if(this.csrf != null) {
+ this.csrf.configure(this);
+ }
if(this.httpBasic != null) {
this.httpBasic.authenticationManager(this.authenticationManager);
if(this.serverSecurityContextRepository != null) {
@@ -340,6 +355,53 @@ public class ServerHttpSecurity {
}
}
+ /**
+ * @author Rob Winch
+ * @since 5.0
+ */
+ public class CsrfBuilder {
+ private CsrfWebFilter filter = new CsrfWebFilter();
+
+ public CsrfBuilder serverAccessDeniedHandler(
+ ServerAccessDeniedHandler serverAccessDeniedHandler) {
+ this.filter.setServerAccessDeniedHandler(serverAccessDeniedHandler);
+ return this;
+ }
+
+ public CsrfBuilder csrfTokenAttributeName(String csrfTokenAttributeName) {
+ Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
+ this.filter.setCsrfTokenAttributeName(csrfTokenAttributeName);
+ return this;
+ }
+
+ public CsrfBuilder serverCsrfTokenRepository(
+ ServerCsrfTokenRepository serverCsrfTokenRepository) {
+ this.filter.setServerCsrfTokenRepository(serverCsrfTokenRepository);
+ return this;
+ }
+
+ public CsrfBuilder requireCsrfProtectionMatcher(
+ ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
+ this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
+ return this;
+ }
+
+ public ServerHttpSecurity and() {
+ return ServerHttpSecurity.this;
+ }
+
+ public ServerHttpSecurity disable() {
+ ServerHttpSecurity.this.csrf = null;
+ return ServerHttpSecurity.this;
+ }
+
+ protected void configure(ServerHttpSecurity http) {
+ http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF);
+ }
+
+ private CsrfBuilder() {}
+ }
+
/**
* @author Rob Winch
* @since 5.0
diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java
index 2bc73b356f..85eb16a3fa 100644
--- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java
+++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java
@@ -55,6 +55,7 @@ import java.nio.charset.StandardCharsets;
import java.security.Principal;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf;
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication;
@@ -213,6 +214,7 @@ public class EnableWebFluxSecurityTests {
data.add("username", "user");
data.add("password", "password");
client
+ .mutateWith(csrf())
.post()
.uri("/login")
.body(BodyInserters.fromFormData(data))
diff --git a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java
index 6f31ea623b..75eb37474a 100644
--- a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java
+++ b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java
@@ -32,6 +32,7 @@ public class AuthorizeExchangeBuilderTests {
@Test
public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() {
this.http
+ .csrf().disable()
.authorizeExchange()
.pathMatchers(HttpMethod.POST, "/a", "/b").denyAll()
.anyExchange().permitAll();
@@ -63,6 +64,7 @@ public class AuthorizeExchangeBuilderTests {
@Test
public void antMatchersWhenPatternsThenAnyMethod() {
this.http
+ .csrf().disable()
.authorizeExchange()
.pathMatchers("/a", "/b").denyAll()
.anyExchange().permitAll();
diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java
index 07c86e2de4..4fb903bc6f 100644
--- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java
+++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java
@@ -26,6 +26,10 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.web.server.csrf.CsrfToken;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.test.web.reactive.server.MockServerConfigurer;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
@@ -107,6 +111,35 @@ public class SecurityMockServerConfigurers {
return new UserExchangeMutator(username);
}
+ public static CsrfMutator csrf() {
+ return new CsrfMutator();
+ }
+
+ public static class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer {
+
+ @Override
+ public void afterConfigurerAdded(WebTestClient.Builder builder,
+ @Nullable WebHttpHandlerBuilder httpHandlerBuilder,
+ @Nullable ClientHttpConnector connector) {
+ CsrfWebFilter filter = new CsrfWebFilter();
+ filter.setRequireCsrfProtectionMatcher( e -> ServerWebExchangeMatcher.MatchResult.notMatch());
+ httpHandlerBuilder.filters( filters -> filters.add(0, filter));
+ }
+
+ @Override
+ public void afterConfigureAdded(
+ WebTestClient.MockServerSpec> serverSpec) {
+
+ }
+
+ @Override
+ public void beforeServerCreated(WebHttpHandlerBuilder builder) {
+
+ }
+
+ private CsrfMutator() {}
+ }
+
/**
* Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a
* password of "password" and granted authorities of "ROLE_USER".
diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java
index d97f79683a..051b7705ed 100644
--- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java
+++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java
@@ -18,15 +18,18 @@ package org.springframework.security.test.web.reactive.server;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.test.web.reactive.server.WebTestClient;
import java.security.Principal;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*;
/**
@@ -36,7 +39,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock
public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests {
WebTestClient client = WebTestClient
.bindToController(controller)
- .webFilter(new SecurityContextServerWebExchangeWebFilter())
+ .webFilter( new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter())
.apply(springSecurity())
.configureClient()
.defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
@@ -144,4 +147,37 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig
assertPrincipalCreatedFromUserDetails(actual, userBuilder.build());
}
+
+ @Test
+ public void csrfWhenMutateWithThenDisablesCsrf() {
+ this.client
+ .post()
+ .exchange()
+ .expectStatus().isEqualTo(HttpStatus.FORBIDDEN)
+ .expectBody().consumeWith( b -> assertThat(new String(b.getResponseBody())).contains("CSRF"));
+
+ this.client
+ .mutateWith(csrf())
+ .post()
+ .exchange()
+ .expectStatus().isOk();
+
+ }
+
+ @Test
+ public void csrfWhenGlobalThenDisablesCsrf() {
+ this.client = WebTestClient
+ .bindToController(this.controller)
+ .webFilter(new CsrfWebFilter())
+ .apply(springSecurity())
+ .apply(csrf())
+ .configureClient()
+ .build();
+
+ this.client
+ .get()
+ .exchange()
+ .expectStatus().isOk();
+
+ }
}
diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java
index b019018cf4..9c1a4c882d 100644
--- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java
+++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java
@@ -52,6 +52,7 @@ public class AuthenticationWebFilter implements WebFilter {
private ServerSecurityContextRepository serverSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance();
private ServerWebExchangeMatcher requiresAuthenticationMatcher = ServerWebExchangeMatchers.anyExchange();
+
public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) {
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
this.authenticationManager = authenticationManager;
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java
new file mode 100644
index 0000000000..51f5c8c498
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.web.csrf.CsrfToken;
+
+/**
+ * Thrown when an invalid or missing {@link CsrfToken} is found in the HttpServletRequest
+ *
+ * @author Rob Winch
+ * @since 3.2
+ */
+@SuppressWarnings("serial")
+public class CsrfException extends AccessDeniedException {
+
+ public CsrfException(String message) {
+ super(message);
+ }
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java
new file mode 100644
index 0000000000..e4ac06dec4
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import java.io.Serializable;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public interface CsrfToken extends Serializable {
+
+ /**
+ * Gets the HTTP header that the CSRF is populated on the response and can be placed
+ * on requests instead of the parameter. Cannot be null.
+ *
+ * @return the HTTP header that the CSRF is populated on the response and can be
+ * placed on requests instead of the parameter
+ */
+ String getHeaderName();
+
+ /**
+ * Gets the HTTP parameter name that should contain the token. Cannot be null.
+ * @return the HTTP parameter name that should contain the token.
+ */
+ String getParameterName();
+
+ /**
+ * Gets the token value. Cannot be null.
+ * @return the token value
+ */
+ String getToken();
+
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
new file mode 100644
index 0000000000..e0a7613425
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
+import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilter;
+import org.springframework.web.server.WebFilterChain;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ *
+ * Applies
+ * CSRF
+ * protection using a synchronizer token pattern. Developers are required to ensure that
+ * {@link CsrfWebFilter} is invoked for any request that allows state to change. Typically
+ * this just means that they should ensure their web application follows proper REST
+ * semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS).
+ *
+ *
+ *
+ * Typically the {@link ServerCsrfTokenRepository} implementation chooses to store the
+ * {@link CsrfToken} in {@link org.springframework.web.server.WebSession} with
+ * {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in
+ * a cookie which can be modified by a client application.
+ *
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class CsrfWebFilter implements WebFilter {
+
+ private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher();
+
+ private ServerCsrfTokenRepository serverCsrfTokenRepository = new WebSessionServerCsrfTokenRepository();
+
+ private ServerAccessDeniedHandler serverAccessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
+
+ private String csrfTokenAttributeName = "csrf";
+
+ public void setServerAccessDeniedHandler(
+ ServerAccessDeniedHandler serverAccessDeniedHandler) {
+ Assert.notNull(serverAccessDeniedHandler, "serverAccessDeniedHandler");
+ this.serverAccessDeniedHandler = serverAccessDeniedHandler;
+ }
+
+ public void setCsrfTokenAttributeName(String csrfTokenAttributeName) {
+ Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
+ this.csrfTokenAttributeName = csrfTokenAttributeName;
+ }
+
+ public void setServerCsrfTokenRepository(
+ ServerCsrfTokenRepository serverCsrfTokenRepository) {
+ Assert.notNull(serverCsrfTokenRepository, "serverCsrfTokenRepository cannot be null");
+ this.serverCsrfTokenRepository = serverCsrfTokenRepository;
+ }
+
+ public void setRequireCsrfProtectionMatcher(
+ ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
+ Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
+ this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
+ }
+
+ @Override
+ public Mono filter(ServerWebExchange exchange, WebFilterChain chain) {
+ return this.requireCsrfProtectionMatcher.matches(exchange)
+ .filter( matchResult -> matchResult.isMatch())
+ .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
+ .flatMap(m -> validateToken(exchange))
+ .flatMap(m -> continueFilterChain(exchange, chain))
+ .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
+ .onErrorResume(CsrfException.class, e -> this.serverAccessDeniedHandler.handle(exchange, e));
+ }
+
+ private Mono validateToken(ServerWebExchange exchange) {
+ return this.serverCsrfTokenRepository.loadToken(exchange)
+ .switchIfEmpty(Mono.error(new CsrfException("CSRF Token has been associated to this client")))
+ .filterWhen(expected -> containsValidCsrfToken(exchange, expected))
+ .switchIfEmpty(Mono.error(new CsrfException("Invalid CSRF Token")))
+ .then();
+ }
+
+ private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
+ return exchange.getFormData()
+ .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
+ .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
+ .map(actual -> actual.equals(expected.getToken()));
+ }
+
+ private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
+ return csrfToken(exchange)
+ .doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
+ .doOnSuccess(csrfToken -> exchange.getAttributes().put(this.csrfTokenAttributeName, csrfToken))
+ .flatMap( t -> chain.filter(exchange))
+ .then();
+ }
+
+ private Mono> csrfToken(ServerWebExchange exchange) {
+ return this.serverCsrfTokenRepository.loadToken(exchange)
+ .switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
+ .as(Mono::just); // FIXME eager saving of CsrfToken with .as
+ }
+
+ private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
+ private static final Set ALLOWED_METHODS = new HashSet<>(
+ Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
+
+ @Override
+ public Mono matches(ServerWebExchange exchange) {
+ return Mono.just(exchange.getRequest())
+ .map(r -> r.getMethod())
+ .filter(m -> ALLOWED_METHODS.contains(m))
+ .flatMap(m -> MatchResult.notMatch())
+ .switchIfEmpty(MatchResult.match());
+ }
+ }
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java
new file mode 100644
index 0000000000..0e75316ebb
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.util.Assert;
+
+/**
+ * A CSRF token that is used to protect against CSRF attacks.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+@SuppressWarnings("serial")
+public final class DefaultCsrfToken implements CsrfToken {
+
+ private final String token;
+
+ private final String parameterName;
+
+ private final String headerName;
+
+ /**
+ * Creates a new instance
+ * @param headerName the HTTP header name to use
+ * @param parameterName the HTTP parameter name to use
+ * @param token the value of the token (i.e. expected value of the HTTP parameter of
+ * parametername).
+ */
+ public DefaultCsrfToken(String headerName, String parameterName, String token) {
+ Assert.hasLength(headerName, "headerName cannot be null or empty");
+ Assert.hasLength(parameterName, "parameterName cannot be null or empty");
+ Assert.hasLength(token, "token cannot be null or empty");
+ this.headerName = headerName;
+ this.parameterName = parameterName;
+ this.token = token;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
+ */
+ public String getHeaderName() {
+ return this.headerName;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
+ */
+ public String getParameterName() {
+ return this.parameterName;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.springframework.security.web.csrf.CsrfToken#getToken()
+ */
+ public String getToken() {
+ return this.token;
+ }
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java
new file mode 100644
index 0000000000..b09ee4ec31
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+/**
+ * An API to allow changing the method in which the expected {@link CsrfToken} is
+ * associated to the {@link ServerWebExchange}. For example, it may be stored in
+ * {@link org.springframework.web.server.WebSession}.
+ *
+ * @see WebSessionServerCsrfTokenRepository
+ *
+ * @author Rob Winch
+ * @since 5.0
+ *
+ */
+public interface ServerCsrfTokenRepository {
+
+ /**
+ * Generates a {@link CsrfToken}
+ *
+ * @param exchange the {@link ServerWebExchange} to use
+ * @return the {@link CsrfToken} that was generated. Cannot be null.
+ */
+ Mono generateToken(ServerWebExchange exchange);
+
+ /**
+ * Saves the {@link CsrfToken} using the {@link ServerWebExchange}. If the
+ * {@link CsrfToken} is null, it is the same as deleting it.
+ *
+ * @param exchange the {@link ServerWebExchange} to use
+ * @param token the {@link CsrfToken} to save or null to delete
+ */
+ Mono saveToken(ServerWebExchange exchange, CsrfToken token);
+
+ /**
+ * Loads the expected {@link CsrfToken} from the {@link ServerWebExchange}
+ *
+ * @param exchange the {@link ServerWebExchange} to use
+ * @return the {@link CsrfToken} or null if none exists
+ */
+ Mono loadToken(ServerWebExchange exchange);
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java
new file mode 100644
index 0000000000..da4281f858
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpSession;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * A {@link ServerCsrfTokenRepository} that stores the {@link CsrfToken} in the
+ * {@link HttpSession}.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class WebSessionServerCsrfTokenRepository
+ implements ServerCsrfTokenRepository {
+ private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
+
+ private static final String DEFAULT_CSRF_HEADER_NAME = "X-CSRF-TOKEN";
+
+ private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME = WebSessionServerCsrfTokenRepository.class
+ .getName().concat(".CSRF_TOKEN");
+
+ private String parameterName = DEFAULT_CSRF_PARAMETER_NAME;
+
+ private String headerName = DEFAULT_CSRF_HEADER_NAME;
+
+ private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;
+
+ @Override
+ public Mono generateToken(ServerWebExchange exchange) {
+ return Mono.defer(() -> Mono.just(createCsrfToken()))
+ .flatMap(token -> save(exchange, token));
+ }
+
+ @Override
+ public Mono saveToken(ServerWebExchange exchange, CsrfToken token) {
+ return save(exchange, token)
+ .then();
+ }
+
+ private Mono save(ServerWebExchange exchange, CsrfToken token) {
+ return exchange.getSession()
+ .map(WebSession::getAttributes)
+ .flatMap( attrs -> save(attrs, token));
+ }
+
+ private Mono save(Map attributes, CsrfToken token) {
+ if(token == null) {
+ attributes.remove(this.sessionAttributeName);
+ } else {
+ attributes.put(this.sessionAttributeName, token);
+ }
+ return Mono.justOrEmpty(token);
+ }
+
+ @Override
+ public Mono loadToken(ServerWebExchange exchange) {
+ return exchange.getSession()
+ .filter( s -> s.getAttributes().containsKey(this.sessionAttributeName))
+ .map(s -> s.getAttribute(this.sessionAttributeName));
+ }
+
+ /**
+ * Sets the {@link HttpServletRequest} parameter name that the {@link CsrfToken} is
+ * expected to appear on
+ * @param parameterName the new parameter name to use
+ */
+ public void setParameterName(String parameterName) {
+ Assert.hasLength(parameterName, "parameterName cannot be null or empty");
+ this.parameterName = parameterName;
+ }
+
+ /**
+ * Sets the header name that the {@link CsrfToken} is expected to appear on and the
+ * header that the response will contain the {@link CsrfToken}.
+ *
+ * @param headerName the new header name to use
+ */
+ public void setHeaderName(String headerName) {
+ Assert.hasLength(headerName, "headerName cannot be null or empty");
+ this.headerName = headerName;
+ }
+
+ /**
+ * Sets the {@link HttpSession} attribute name that the {@link CsrfToken} is stored in
+ * @param sessionAttributeName the new attribute name to use
+ */
+ public void setSessionAttributeName(String sessionAttributeName) {
+ Assert.hasLength(sessionAttributeName,
+ "sessionAttributename cannot be null or empty");
+ this.sessionAttributeName = sessionAttributeName;
+ }
+
+ private CsrfToken createCsrfToken() {
+ return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
+ }
+
+ private String createNewToken() {
+ return UUID.randomUUID().toString();
+ }
+}
diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java
index 2850105238..78d8a463a0 100644
--- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java
+++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java
@@ -23,6 +23,7 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.MultiValueMap;
@@ -50,21 +51,31 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
}
private Mono render(ServerWebExchange exchange) {
- MultiValueMap queryParams = exchange.getRequest()
- .getQueryParams();
- boolean isError = queryParams.containsKey("error");
- boolean isLogoutSuccess = queryParams.containsKey("logout");
ServerHttpResponse result = exchange.getResponse();
- result.setStatusCode(HttpStatus.FOUND);
+ result.setStatusCode(HttpStatus.OK);
result.getHeaders().setContentType(MediaType.TEXT_HTML);
- byte[] bytes = createPage(isError, isLogoutSuccess);
- DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
- DataBuffer buffer = bufferFactory.wrap(bytes);
- return result.writeWith(Mono.just(buffer))
- .doOnError( error -> DataBufferUtils.release(buffer));
+ return result.writeWith(createBuffer(exchange));
+// .doOnError( error -> DataBufferUtils.release(buffer));
}
- private static byte[] createPage(boolean isError, boolean isLogoutSuccess) {
+ private Mono createBuffer(ServerWebExchange exchange) {
+ MultiValueMap queryParams = exchange.getRequest()
+ .getQueryParams();
+ Mono token = (Mono) exchange.getAttributes()
+ .getOrDefault(CsrfToken.class.getName(), Mono.empty());
+ return token
+ .map(LoginPageGeneratingWebFilter::csrfToken)
+ .defaultIfEmpty("")
+ .map(csrfTokenHtmlInput -> {
+ boolean isError = queryParams.containsKey("error");
+ boolean isLogoutSuccess = queryParams.containsKey("logout");
+ byte[] bytes = createPage(isError, isLogoutSuccess, csrfTokenHtmlInput);
+ DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
+ return bufferFactory.wrap(bytes);
+ });
+ }
+
+ private static byte[] createPage(boolean isError, boolean isLogoutSuccess, String csrfTokenHtmlInput) {
String page = "\n"
+ "\n"
+ " \n"
@@ -90,6 +101,7 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
+ " \n"
+ " \n"
+ "