diff --git a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java index c13a4f68bb..54ada6f56a 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java +++ b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java @@ -23,6 +23,10 @@ package org.springframework.security.config.web.server; public enum SecurityWebFiltersOrder { FIRST(Integer.MIN_VALUE), HTTP_HEADERS_WRITER, + /** + * {@link org.springframework.security.web.server.csrf.CsrfWebFilter} + */ + CSRF, /** * Instance of AuthenticationWebFilter */ diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 389b75d79f..e9d5a6a9a0 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -44,11 +44,14 @@ import org.springframework.security.web.server.authorization.AuthorizationContex import org.springframework.security.web.server.authorization.AuthorizationWebFilter; import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager; import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter; +import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ReactorContextWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.security.web.server.csrf.CsrfWebFilter; +import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter; import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter; import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter; @@ -90,6 +93,8 @@ public class ServerHttpSecurity { private HeaderBuilder headers; + private CsrfBuilder csrf = new CsrfBuilder(); + private HttpBasicBuilder httpBasic; private FormLoginBuilder formLogin; @@ -139,6 +144,13 @@ public class ServerHttpSecurity { return this; } + public CsrfBuilder csrf() { + if(this.csrf == null) { + this.csrf = new CsrfBuilder(); + } + return this.csrf; + } + public HttpBasicBuilder httpBasic() { if(this.httpBasic == null) { this.httpBasic = new HttpBasicBuilder(); @@ -191,6 +203,9 @@ public class ServerHttpSecurity { if(securityContextRepositoryWebFilter != null) { this.webFilters.add(securityContextRepositoryWebFilter); } + if(this.csrf != null) { + this.csrf.configure(this); + } if(this.httpBasic != null) { this.httpBasic.authenticationManager(this.authenticationManager); if(this.serverSecurityContextRepository != null) { @@ -340,6 +355,53 @@ public class ServerHttpSecurity { } } + /** + * @author Rob Winch + * @since 5.0 + */ + public class CsrfBuilder { + private CsrfWebFilter filter = new CsrfWebFilter(); + + public CsrfBuilder serverAccessDeniedHandler( + ServerAccessDeniedHandler serverAccessDeniedHandler) { + this.filter.setServerAccessDeniedHandler(serverAccessDeniedHandler); + return this; + } + + public CsrfBuilder csrfTokenAttributeName(String csrfTokenAttributeName) { + Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null"); + this.filter.setCsrfTokenAttributeName(csrfTokenAttributeName); + return this; + } + + public CsrfBuilder serverCsrfTokenRepository( + ServerCsrfTokenRepository serverCsrfTokenRepository) { + this.filter.setServerCsrfTokenRepository(serverCsrfTokenRepository); + return this; + } + + public CsrfBuilder requireCsrfProtectionMatcher( + ServerWebExchangeMatcher requireCsrfProtectionMatcher) { + this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); + return this; + } + + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + public ServerHttpSecurity disable() { + ServerHttpSecurity.this.csrf = null; + return ServerHttpSecurity.this; + } + + protected void configure(ServerHttpSecurity http) { + http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF); + } + + private CsrfBuilder() {} + } + /** * @author Rob Winch * @since 5.0 diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java index 2bc73b356f..85eb16a3fa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java @@ -55,6 +55,7 @@ import java.nio.charset.StandardCharsets; import java.security.Principal; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf; import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials; import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication; @@ -213,6 +214,7 @@ public class EnableWebFluxSecurityTests { data.add("username", "user"); data.add("password", "password"); client + .mutateWith(csrf()) .post() .uri("/login") .body(BodyInserters.fromFormData(data)) diff --git a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java index 6f31ea623b..75eb37474a 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java @@ -32,6 +32,7 @@ public class AuthorizeExchangeBuilderTests { @Test public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() { this.http + .csrf().disable() .authorizeExchange() .pathMatchers(HttpMethod.POST, "/a", "/b").denyAll() .anyExchange().permitAll(); @@ -63,6 +64,7 @@ public class AuthorizeExchangeBuilderTests { @Test public void antMatchersWhenPatternsThenAnyMethod() { this.http + .csrf().disable() .authorizeExchange() .pathMatchers("/a", "/b").denyAll() .anyExchange().permitAll(); diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index 07c86e2de4..4fb903bc6f 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -26,6 +26,10 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.web.server.csrf.CsrfToken; +import org.springframework.security.web.server.csrf.CsrfWebFilter; +import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.test.web.reactive.server.MockServerConfigurer; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClientConfigurer; @@ -107,6 +111,35 @@ public class SecurityMockServerConfigurers { return new UserExchangeMutator(username); } + public static CsrfMutator csrf() { + return new CsrfMutator(); + } + + public static class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer { + + @Override + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, + @Nullable ClientHttpConnector connector) { + CsrfWebFilter filter = new CsrfWebFilter(); + filter.setRequireCsrfProtectionMatcher( e -> ServerWebExchangeMatcher.MatchResult.notMatch()); + httpHandlerBuilder.filters( filters -> filters.add(0, filter)); + } + + @Override + public void afterConfigureAdded( + WebTestClient.MockServerSpec serverSpec) { + + } + + @Override + public void beforeServerCreated(WebHttpHandlerBuilder builder) { + + } + + private CsrfMutator() {} + } + /** * Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a * password of "password" and granted authorities of "ROLE_USER". diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java index d97f79683a..051b7705ed 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java @@ -18,15 +18,18 @@ package org.springframework.security.test.web.reactive.server; import org.junit.Test; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; +import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.test.web.reactive.server.WebTestClient; import java.security.Principal; +import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*; /** @@ -36,7 +39,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests { WebTestClient client = WebTestClient .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .webFilter( new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) @@ -144,4 +147,37 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); } + + @Test + public void csrfWhenMutateWithThenDisablesCsrf() { + this.client + .post() + .exchange() + .expectStatus().isEqualTo(HttpStatus.FORBIDDEN) + .expectBody().consumeWith( b -> assertThat(new String(b.getResponseBody())).contains("CSRF")); + + this.client + .mutateWith(csrf()) + .post() + .exchange() + .expectStatus().isOk(); + + } + + @Test + public void csrfWhenGlobalThenDisablesCsrf() { + this.client = WebTestClient + .bindToController(this.controller) + .webFilter(new CsrfWebFilter()) + .apply(springSecurity()) + .apply(csrf()) + .configureClient() + .build(); + + this.client + .get() + .exchange() + .expectStatus().isOk(); + + } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index b019018cf4..9c1a4c882d 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -52,6 +52,7 @@ public class AuthenticationWebFilter implements WebFilter { private ServerSecurityContextRepository serverSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance(); private ServerWebExchangeMatcher requiresAuthenticationMatcher = ServerWebExchangeMatchers.anyExchange(); + public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); this.authenticationManager = authenticationManager; diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java new file mode 100644 index 0000000000..51f5c8c498 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.server.csrf; + +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.web.csrf.CsrfToken; + +/** + * Thrown when an invalid or missing {@link CsrfToken} is found in the HttpServletRequest + * + * @author Rob Winch + * @since 3.2 + */ +@SuppressWarnings("serial") +public class CsrfException extends AccessDeniedException { + + public CsrfException(String message) { + super(message); + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java new file mode 100644 index 0000000000..e4ac06dec4 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import java.io.Serializable; + +/** + * @author Rob Winch + * @since 5.0 + */ +public interface CsrfToken extends Serializable { + + /** + * Gets the HTTP header that the CSRF is populated on the response and can be placed + * on requests instead of the parameter. Cannot be null. + * + * @return the HTTP header that the CSRF is populated on the response and can be + * placed on requests instead of the parameter + */ + String getHeaderName(); + + /** + * Gets the HTTP parameter name that should contain the token. Cannot be null. + * @return the HTTP parameter name that should contain the token. + */ + String getParameterName(); + + /** + * Gets the token value. Cannot be null. + * @return the token value + */ + String getToken(); + +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java new file mode 100644 index 0000000000..e0a7613425 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -0,0 +1,140 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; +import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +/** + *

+ * Applies + * CSRF + * protection using a synchronizer token pattern. Developers are required to ensure that + * {@link CsrfWebFilter} is invoked for any request that allows state to change. Typically + * this just means that they should ensure their web application follows proper REST + * semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS). + *

+ * + *

+ * Typically the {@link ServerCsrfTokenRepository} implementation chooses to store the + * {@link CsrfToken} in {@link org.springframework.web.server.WebSession} with + * {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in + * a cookie which can be modified by a client application. + *

+ * + * @author Rob Winch + * @since 5.0 + */ +public class CsrfWebFilter implements WebFilter { + + private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher(); + + private ServerCsrfTokenRepository serverCsrfTokenRepository = new WebSessionServerCsrfTokenRepository(); + + private ServerAccessDeniedHandler serverAccessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); + + private String csrfTokenAttributeName = "csrf"; + + public void setServerAccessDeniedHandler( + ServerAccessDeniedHandler serverAccessDeniedHandler) { + Assert.notNull(serverAccessDeniedHandler, "serverAccessDeniedHandler"); + this.serverAccessDeniedHandler = serverAccessDeniedHandler; + } + + public void setCsrfTokenAttributeName(String csrfTokenAttributeName) { + Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null"); + this.csrfTokenAttributeName = csrfTokenAttributeName; + } + + public void setServerCsrfTokenRepository( + ServerCsrfTokenRepository serverCsrfTokenRepository) { + Assert.notNull(serverCsrfTokenRepository, "serverCsrfTokenRepository cannot be null"); + this.serverCsrfTokenRepository = serverCsrfTokenRepository; + } + + public void setRequireCsrfProtectionMatcher( + ServerWebExchangeMatcher requireCsrfProtectionMatcher) { + Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); + this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return this.requireCsrfProtectionMatcher.matches(exchange) + .filter( matchResult -> matchResult.isMatch()) + .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) + .flatMap(m -> validateToken(exchange)) + .flatMap(m -> continueFilterChain(exchange, chain)) + .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) + .onErrorResume(CsrfException.class, e -> this.serverAccessDeniedHandler.handle(exchange, e)); + } + + private Mono validateToken(ServerWebExchange exchange) { + return this.serverCsrfTokenRepository.loadToken(exchange) + .switchIfEmpty(Mono.error(new CsrfException("CSRF Token has been associated to this client"))) + .filterWhen(expected -> containsValidCsrfToken(exchange, expected)) + .switchIfEmpty(Mono.error(new CsrfException("Invalid CSRF Token"))) + .then(); + } + + private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { + return exchange.getFormData() + .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) + .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) + .map(actual -> actual.equals(expected.getToken())); + } + + private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { + return csrfToken(exchange) + .doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken)) + .doOnSuccess(csrfToken -> exchange.getAttributes().put(this.csrfTokenAttributeName, csrfToken)) + .flatMap( t -> chain.filter(exchange)) + .then(); + } + + private Mono> csrfToken(ServerWebExchange exchange) { + return this.serverCsrfTokenRepository.loadToken(exchange) + .switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange)) + .as(Mono::just); // FIXME eager saving of CsrfToken with .as + } + + private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { + private static final Set ALLOWED_METHODS = new HashSet<>( + Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); + + @Override + public Mono matches(ServerWebExchange exchange) { + return Mono.just(exchange.getRequest()) + .map(r -> r.getMethod()) + .filter(m -> ALLOWED_METHODS.contains(m)) + .flatMap(m -> MatchResult.notMatch()) + .switchIfEmpty(MatchResult.match()); + } + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java new file mode 100644 index 0000000000..0e75316ebb --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.server.csrf; + +import org.springframework.util.Assert; + +/** + * A CSRF token that is used to protect against CSRF attacks. + * + * @author Rob Winch + * @since 5.0 + */ +@SuppressWarnings("serial") +public final class DefaultCsrfToken implements CsrfToken { + + private final String token; + + private final String parameterName; + + private final String headerName; + + /** + * Creates a new instance + * @param headerName the HTTP header name to use + * @param parameterName the HTTP parameter name to use + * @param token the value of the token (i.e. expected value of the HTTP parameter of + * parametername). + */ + public DefaultCsrfToken(String headerName, String parameterName, String token) { + Assert.hasLength(headerName, "headerName cannot be null or empty"); + Assert.hasLength(parameterName, "parameterName cannot be null or empty"); + Assert.hasLength(token, "token cannot be null or empty"); + this.headerName = headerName; + this.parameterName = parameterName; + this.token = token; + } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName() + */ + public String getHeaderName() { + return this.headerName; + } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.csrf.CsrfToken#getParameterName() + */ + public String getParameterName() { + return this.parameterName; + } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.csrf.CsrfToken#getToken() + */ + public String getToken() { + return this.token; + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java new file mode 100644 index 0000000000..b09ee4ec31 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.server.csrf; + +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +/** + * An API to allow changing the method in which the expected {@link CsrfToken} is + * associated to the {@link ServerWebExchange}. For example, it may be stored in + * {@link org.springframework.web.server.WebSession}. + * + * @see WebSessionServerCsrfTokenRepository + * + * @author Rob Winch + * @since 5.0 + * + */ +public interface ServerCsrfTokenRepository { + + /** + * Generates a {@link CsrfToken} + * + * @param exchange the {@link ServerWebExchange} to use + * @return the {@link CsrfToken} that was generated. Cannot be null. + */ + Mono generateToken(ServerWebExchange exchange); + + /** + * Saves the {@link CsrfToken} using the {@link ServerWebExchange}. If the + * {@link CsrfToken} is null, it is the same as deleting it. + * + * @param exchange the {@link ServerWebExchange} to use + * @param token the {@link CsrfToken} to save or null to delete + */ + Mono saveToken(ServerWebExchange exchange, CsrfToken token); + + /** + * Loads the expected {@link CsrfToken} from the {@link ServerWebExchange} + * + * @param exchange the {@link ServerWebExchange} to use + * @return the {@link CsrfToken} or null if none exists + */ + Mono loadToken(ServerWebExchange exchange); +} diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java new file mode 100644 index 0000000000..da4281f858 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.server.csrf; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; +import java.util.Map; +import java.util.UUID; + +/** + * A {@link ServerCsrfTokenRepository} that stores the {@link CsrfToken} in the + * {@link HttpSession}. + * + * @author Rob Winch + * @since 5.0 + */ +public class WebSessionServerCsrfTokenRepository + implements ServerCsrfTokenRepository { + private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf"; + + private static final String DEFAULT_CSRF_HEADER_NAME = "X-CSRF-TOKEN"; + + private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME = WebSessionServerCsrfTokenRepository.class + .getName().concat(".CSRF_TOKEN"); + + private String parameterName = DEFAULT_CSRF_PARAMETER_NAME; + + private String headerName = DEFAULT_CSRF_HEADER_NAME; + + private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME; + + @Override + public Mono generateToken(ServerWebExchange exchange) { + return Mono.defer(() -> Mono.just(createCsrfToken())) + .flatMap(token -> save(exchange, token)); + } + + @Override + public Mono saveToken(ServerWebExchange exchange, CsrfToken token) { + return save(exchange, token) + .then(); + } + + private Mono save(ServerWebExchange exchange, CsrfToken token) { + return exchange.getSession() + .map(WebSession::getAttributes) + .flatMap( attrs -> save(attrs, token)); + } + + private Mono save(Map attributes, CsrfToken token) { + if(token == null) { + attributes.remove(this.sessionAttributeName); + } else { + attributes.put(this.sessionAttributeName, token); + } + return Mono.justOrEmpty(token); + } + + @Override + public Mono loadToken(ServerWebExchange exchange) { + return exchange.getSession() + .filter( s -> s.getAttributes().containsKey(this.sessionAttributeName)) + .map(s -> s.getAttribute(this.sessionAttributeName)); + } + + /** + * Sets the {@link HttpServletRequest} parameter name that the {@link CsrfToken} is + * expected to appear on + * @param parameterName the new parameter name to use + */ + public void setParameterName(String parameterName) { + Assert.hasLength(parameterName, "parameterName cannot be null or empty"); + this.parameterName = parameterName; + } + + /** + * Sets the header name that the {@link CsrfToken} is expected to appear on and the + * header that the response will contain the {@link CsrfToken}. + * + * @param headerName the new header name to use + */ + public void setHeaderName(String headerName) { + Assert.hasLength(headerName, "headerName cannot be null or empty"); + this.headerName = headerName; + } + + /** + * Sets the {@link HttpSession} attribute name that the {@link CsrfToken} is stored in + * @param sessionAttributeName the new attribute name to use + */ + public void setSessionAttributeName(String sessionAttributeName) { + Assert.hasLength(sessionAttributeName, + "sessionAttributename cannot be null or empty"); + this.sessionAttributeName = sessionAttributeName; + } + + private CsrfToken createCsrfToken() { + return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken()); + } + + private String createNewToken() { + return UUID.randomUUID().toString(); + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java index 2850105238..78d8a463a0 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java @@ -23,6 +23,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.security.web.server.csrf.CsrfToken; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.util.MultiValueMap; @@ -50,21 +51,31 @@ public class LoginPageGeneratingWebFilter implements WebFilter { } private Mono render(ServerWebExchange exchange) { - MultiValueMap queryParams = exchange.getRequest() - .getQueryParams(); - boolean isError = queryParams.containsKey("error"); - boolean isLogoutSuccess = queryParams.containsKey("logout"); ServerHttpResponse result = exchange.getResponse(); - result.setStatusCode(HttpStatus.FOUND); + result.setStatusCode(HttpStatus.OK); result.getHeaders().setContentType(MediaType.TEXT_HTML); - byte[] bytes = createPage(isError, isLogoutSuccess); - DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); - DataBuffer buffer = bufferFactory.wrap(bytes); - return result.writeWith(Mono.just(buffer)) - .doOnError( error -> DataBufferUtils.release(buffer)); + return result.writeWith(createBuffer(exchange)); +// .doOnError( error -> DataBufferUtils.release(buffer)); } - private static byte[] createPage(boolean isError, boolean isLogoutSuccess) { + private Mono createBuffer(ServerWebExchange exchange) { + MultiValueMap queryParams = exchange.getRequest() + .getQueryParams(); + Mono token = (Mono) exchange.getAttributes() + .getOrDefault(CsrfToken.class.getName(), Mono.empty()); + return token + .map(LoginPageGeneratingWebFilter::csrfToken) + .defaultIfEmpty("") + .map(csrfTokenHtmlInput -> { + boolean isError = queryParams.containsKey("error"); + boolean isLogoutSuccess = queryParams.containsKey("logout"); + byte[] bytes = createPage(isError, isLogoutSuccess, csrfTokenHtmlInput); + DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); + return bufferFactory.wrap(bytes); + }); + } + + private static byte[] createPage(boolean isError, boolean isLogoutSuccess, String csrfTokenHtmlInput) { String page = "\n" + "\n" + " \n" @@ -90,6 +101,7 @@ public class LoginPageGeneratingWebFilter implements WebFilter { + " \n" + " \n" + "

\n" + + csrfTokenHtmlInput + " \n" + " \n" + " \n" @@ -99,6 +111,10 @@ public class LoginPageGeneratingWebFilter implements WebFilter { return page.getBytes(Charset.defaultCharset()); } + private static String csrfToken(CsrfToken token) { + return " \n"; + } + private static String createError(boolean isError) { return isError ? "
Invalid credentials
" : ""; } diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java new file mode 100644 index 0000000000..4c118f6fce --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; + +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + * @since 5.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class CsrfWebFilterTests { + @Mock + private WebFilterChain chain; + @Mock + private ServerCsrfTokenRepository repository; + + private CsrfToken token = new DefaultCsrfToken("csrf", "CSRF", "a"); + + private CsrfWebFilter csrfFilter = new CsrfWebFilter(); + + private MockServerWebExchange get = MockServerWebExchange.from( + MockServerHttpRequest.get("/")); + + private MockServerWebExchange post = MockServerWebExchange.from( + MockServerHttpRequest.post("/")); + + @Test + public void filterWhenGetThenSessionNotCreatedAndChainContinues() { + PublisherProbe chainResult = PublisherProbe.empty(); + when(this.chain.filter(this.get)).thenReturn(chainResult.mono()); + + Mono result = this.csrfFilter.filter(this.get, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + Mono isSessionStarted = this.get.getSession() + .map(WebSession::isStarted); + StepVerifier.create(isSessionStarted) + .expectNext(false) + .verifyComplete(); + + chainResult.assertWasSubscribed(); + } + + @Test + public void filterWhenPostAndNoTokenThenCsrfException() { + Mono result = this.csrfFilter.filter(this.post, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + } + + @Test + public void filterWhenPostAndEstablishedCsrfTokenAndRequestMissingTokenThenCsrfException() { + this.csrfFilter.setServerCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + + Mono result = this.csrfFilter.filter(this.post, this.chain); + + + StepVerifier.create(result) + .verifyComplete(); + + assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + } + + @Test + public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamInvalidTokenThenCsrfException() { + this.csrfFilter.setServerCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + .body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID")); + + Mono result = this.csrfFilter.filter(this.post, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + } + + @Test + public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenContinues() { + PublisherProbe chainResult = PublisherProbe.empty(); + when(this.chain.filter(any())).thenReturn(chainResult.mono()); + + this.csrfFilter.setServerCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body(this.token.getParameterName() + "="+this.token.getToken())); + + Mono result = this.csrfFilter.filter(this.post, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + chainResult.assertWasSubscribed(); + } + + @Test + public void filterWhenPostAndEstablishedCsrfTokenAndHeaderInvalidTokenThenCsrfException() { + this.csrfFilter.setServerCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + .header(this.token.getHeaderName(), this.token.getToken()+"INVALID")); + + Mono result = this.csrfFilter.filter(this.post, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + } + + @Test + public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinues() { + PublisherProbe chainResult = PublisherProbe.empty(); + when(this.chain.filter(any())).thenReturn(chainResult.mono()); + + this.csrfFilter.setServerCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + .header(this.token.getHeaderName(), this.token.getToken())); + + Mono result = this.csrfFilter.filter(this.post, this.chain); + + StepVerifier.create(result) + .verifyComplete(); + + chainResult.assertWasSubscribed(); + } +} diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java new file mode 100644 index 0000000000..c9ed4ae708 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.csrf; + +import org.junit.Test; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + * @since 5.0 + */ +public class WebSessionServerCsrfTokenRepositoryTests { + private WebSessionServerCsrfTokenRepository repository = new WebSessionServerCsrfTokenRepository(); + + private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + + @Test + public void generateTokenWhenNoSubscriptionThenNoSession() { + Mono result = this.repository.generateToken(this.exchange); + + Mono isSessionStarted = this.exchange.getSession() + .map(WebSession::isStarted); + + StepVerifier.create(isSessionStarted) + .expectNext(false) + .verifyComplete(); + } + + @Test + public void generateTokenWhenSubscriptionThenAddsToSession() { + Mono result = this.repository.generateToken(this.exchange); + + StepVerifier.create(result) + .consumeNextWith( t -> assertThat(t).isNotNull()) + .verifyComplete(); + + WebSession session = this.exchange.getSession().block(); + Map attributes = session.getAttributes(); + + assertThat(session.isStarted()).isTrue(); + assertThat(attributes).hasSize(1); + assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class); + + } + + @Test + public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() { + CsrfToken token = new DefaultCsrfToken("h","p", "t"); + String attrName = "ATTR"; + this.repository.setSessionAttributeName(attrName); + Mono result = this.repository.saveToken(this.exchange, token); + + StepVerifier.create(result) + .verifyComplete(); + + WebSession session = this.exchange.getSession().block(); + + assertThat(session.isStarted()).isTrue(); + assertThat(session.getAttribute(attrName)).isEqualTo(token); + } + + @Test + public void saveTokenWhenNullThenDeletes() { + CsrfToken token = new DefaultCsrfToken("h","p", "t"); + this.repository.saveToken(this.exchange, token).block(); + + Mono result = this.repository.saveToken(this.exchange, null); + StepVerifier.create(result) + .verifyComplete(); + + WebSession session = this.exchange.getSession().block(); + + assertThat(session.getAttributes()).isEmpty(); + } + + @Test + public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() { + CsrfToken generate = this.repository.generateToken(this.exchange).block(); + + CsrfToken load = this.repository.loadToken(this.exchange).block(); + assertThat(load).isEqualTo(generate); + + this.repository.saveToken(this.exchange, null).block(); + WebSession session = this.exchange.getSession().block(); + assertThat(session.getAttributes()).isEmpty(); + + load = this.repository.loadToken(this.exchange).block(); + assertThat(load).isNull(); + } +}