Add WebFlux CSRF Protection

Fixes gh-4734
This commit is contained in:
Rob Winch 2017-10-28 20:27:57 -05:00
parent f040bd054d
commit 8da2c7f657
16 changed files with 943 additions and 12 deletions

View File

@ -23,6 +23,10 @@ package org.springframework.security.config.web.server;
public enum SecurityWebFiltersOrder {
FIRST(Integer.MIN_VALUE),
HTTP_HEADERS_WRITER,
/**
* {@link org.springframework.security.web.server.csrf.CsrfWebFilter}
*/
CSRF,
/**
* Instance of AuthenticationWebFilter
*/

View File

@ -44,11 +44,14 @@ import org.springframework.security.web.server.authorization.AuthorizationContex
import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.context.ReactorContextWebFilter;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter;
@ -90,6 +93,8 @@ public class ServerHttpSecurity {
private HeaderBuilder headers;
private CsrfBuilder csrf = new CsrfBuilder();
private HttpBasicBuilder httpBasic;
private FormLoginBuilder formLogin;
@ -139,6 +144,13 @@ public class ServerHttpSecurity {
return this;
}
public CsrfBuilder csrf() {
if(this.csrf == null) {
this.csrf = new CsrfBuilder();
}
return this.csrf;
}
public HttpBasicBuilder httpBasic() {
if(this.httpBasic == null) {
this.httpBasic = new HttpBasicBuilder();
@ -191,6 +203,9 @@ public class ServerHttpSecurity {
if(securityContextRepositoryWebFilter != null) {
this.webFilters.add(securityContextRepositoryWebFilter);
}
if(this.csrf != null) {
this.csrf.configure(this);
}
if(this.httpBasic != null) {
this.httpBasic.authenticationManager(this.authenticationManager);
if(this.serverSecurityContextRepository != null) {
@ -340,6 +355,53 @@ public class ServerHttpSecurity {
}
}
/**
* @author Rob Winch
* @since 5.0
*/
public class CsrfBuilder {
private CsrfWebFilter filter = new CsrfWebFilter();
public CsrfBuilder serverAccessDeniedHandler(
ServerAccessDeniedHandler serverAccessDeniedHandler) {
this.filter.setServerAccessDeniedHandler(serverAccessDeniedHandler);
return this;
}
public CsrfBuilder csrfTokenAttributeName(String csrfTokenAttributeName) {
Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
this.filter.setCsrfTokenAttributeName(csrfTokenAttributeName);
return this;
}
public CsrfBuilder serverCsrfTokenRepository(
ServerCsrfTokenRepository serverCsrfTokenRepository) {
this.filter.setServerCsrfTokenRepository(serverCsrfTokenRepository);
return this;
}
public CsrfBuilder requireCsrfProtectionMatcher(
ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
return this;
}
public ServerHttpSecurity and() {
return ServerHttpSecurity.this;
}
public ServerHttpSecurity disable() {
ServerHttpSecurity.this.csrf = null;
return ServerHttpSecurity.this;
}
protected void configure(ServerHttpSecurity http) {
http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF);
}
private CsrfBuilder() {}
}
/**
* @author Rob Winch
* @since 5.0

View File

@ -55,6 +55,7 @@ import java.nio.charset.StandardCharsets;
import java.security.Principal;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf;
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication;
@ -213,6 +214,7 @@ public class EnableWebFluxSecurityTests {
data.add("username", "user");
data.add("password", "password");
client
.mutateWith(csrf())
.post()
.uri("/login")
.body(BodyInserters.fromFormData(data))

View File

@ -32,6 +32,7 @@ public class AuthorizeExchangeBuilderTests {
@Test
public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() {
this.http
.csrf().disable()
.authorizeExchange()
.pathMatchers(HttpMethod.POST, "/a", "/b").denyAll()
.anyExchange().permitAll();
@ -63,6 +64,7 @@ public class AuthorizeExchangeBuilderTests {
@Test
public void antMatchersWhenPatternsThenAnyMethod() {
this.http
.csrf().disable()
.authorizeExchange()
.pathMatchers("/a", "/b").denyAll()
.anyExchange().permitAll();

View File

@ -26,6 +26,10 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.test.web.reactive.server.MockServerConfigurer;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
@ -107,6 +111,35 @@ public class SecurityMockServerConfigurers {
return new UserExchangeMutator(username);
}
public static CsrfMutator csrf() {
return new CsrfMutator();
}
public static class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer {
@Override
public void afterConfigurerAdded(WebTestClient.Builder builder,
@Nullable WebHttpHandlerBuilder httpHandlerBuilder,
@Nullable ClientHttpConnector connector) {
CsrfWebFilter filter = new CsrfWebFilter();
filter.setRequireCsrfProtectionMatcher( e -> ServerWebExchangeMatcher.MatchResult.notMatch());
httpHandlerBuilder.filters( filters -> filters.add(0, filter));
}
@Override
public void afterConfigureAdded(
WebTestClient.MockServerSpec<?> serverSpec) {
}
@Override
public void beforeServerCreated(WebHttpHandlerBuilder builder) {
}
private CsrfMutator() {}
}
/**
* Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a
* password of "password" and granted authorities of "ROLE_USER".

View File

@ -18,15 +18,18 @@ package org.springframework.security.test.web.reactive.server;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.test.web.reactive.server.WebTestClient;
import java.security.Principal;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*;
/**
@ -36,7 +39,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock
public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests {
WebTestClient client = WebTestClient
.bindToController(controller)
.webFilter(new SecurityContextServerWebExchangeWebFilter())
.webFilter( new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter())
.apply(springSecurity())
.configureClient()
.defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
@ -144,4 +147,37 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig
assertPrincipalCreatedFromUserDetails(actual, userBuilder.build());
}
@Test
public void csrfWhenMutateWithThenDisablesCsrf() {
this.client
.post()
.exchange()
.expectStatus().isEqualTo(HttpStatus.FORBIDDEN)
.expectBody().consumeWith( b -> assertThat(new String(b.getResponseBody())).contains("CSRF"));
this.client
.mutateWith(csrf())
.post()
.exchange()
.expectStatus().isOk();
}
@Test
public void csrfWhenGlobalThenDisablesCsrf() {
this.client = WebTestClient
.bindToController(this.controller)
.webFilter(new CsrfWebFilter())
.apply(springSecurity())
.apply(csrf())
.configureClient()
.build();
this.client
.get()
.exchange()
.expectStatus().isOk();
}
}

View File

@ -52,6 +52,7 @@ public class AuthenticationWebFilter implements WebFilter {
private ServerSecurityContextRepository serverSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance();
private ServerWebExchangeMatcher requiresAuthenticationMatcher = ServerWebExchangeMatchers.anyExchange();
public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) {
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
this.authenticationManager = authenticationManager;

View File

@ -0,0 +1,33 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.csrf.CsrfToken;
/**
* Thrown when an invalid or missing {@link CsrfToken} is found in the HttpServletRequest
*
* @author Rob Winch
* @since 3.2
*/
@SuppressWarnings("serial")
public class CsrfException extends AccessDeniedException {
public CsrfException(String message) {
super(message);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import java.io.Serializable;
/**
* @author Rob Winch
* @since 5.0
*/
public interface CsrfToken extends Serializable {
/**
* Gets the HTTP header that the CSRF is populated on the response and can be placed
* on requests instead of the parameter. Cannot be null.
*
* @return the HTTP header that the CSRF is populated on the response and can be
* placed on requests instead of the parameter
*/
String getHeaderName();
/**
* Gets the HTTP parameter name that should contain the token. Cannot be null.
* @return the HTTP parameter name that should contain the token.
*/
String getParameterName();
/**
* Gets the token value. Cannot be null.
* @return the token value
*/
String getToken();
}

View File

@ -0,0 +1,140 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
/**
* <p>
* Applies
* <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" >CSRF</a>
* protection using a synchronizer token pattern. Developers are required to ensure that
* {@link CsrfWebFilter} is invoked for any request that allows state to change. Typically
* this just means that they should ensure their web application follows proper REST
* semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS).
* </p>
*
* <p>
* Typically the {@link ServerCsrfTokenRepository} implementation chooses to store the
* {@link CsrfToken} in {@link org.springframework.web.server.WebSession} with
* {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in
* a cookie which can be modified by a client application.
* </p>
*
* @author Rob Winch
* @since 5.0
*/
public class CsrfWebFilter implements WebFilter {
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher();
private ServerCsrfTokenRepository serverCsrfTokenRepository = new WebSessionServerCsrfTokenRepository();
private ServerAccessDeniedHandler serverAccessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
private String csrfTokenAttributeName = "csrf";
public void setServerAccessDeniedHandler(
ServerAccessDeniedHandler serverAccessDeniedHandler) {
Assert.notNull(serverAccessDeniedHandler, "serverAccessDeniedHandler");
this.serverAccessDeniedHandler = serverAccessDeniedHandler;
}
public void setCsrfTokenAttributeName(String csrfTokenAttributeName) {
Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
this.csrfTokenAttributeName = csrfTokenAttributeName;
}
public void setServerCsrfTokenRepository(
ServerCsrfTokenRepository serverCsrfTokenRepository) {
Assert.notNull(serverCsrfTokenRepository, "serverCsrfTokenRepository cannot be null");
this.serverCsrfTokenRepository = serverCsrfTokenRepository;
}
public void setRequireCsrfProtectionMatcher(
ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return this.requireCsrfProtectionMatcher.matches(exchange)
.filter( matchResult -> matchResult.isMatch())
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
.flatMap(m -> validateToken(exchange))
.flatMap(m -> continueFilterChain(exchange, chain))
.switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
.onErrorResume(CsrfException.class, e -> this.serverAccessDeniedHandler.handle(exchange, e));
}
private Mono<Void> validateToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.loadToken(exchange)
.switchIfEmpty(Mono.error(new CsrfException("CSRF Token has been associated to this client")))
.filterWhen(expected -> containsValidCsrfToken(exchange, expected))
.switchIfEmpty(Mono.error(new CsrfException("Invalid CSRF Token")))
.then();
}
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
return exchange.getFormData()
.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
.map(actual -> actual.equals(expected.getToken()));
}
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return csrfToken(exchange)
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
.doOnSuccess(csrfToken -> exchange.getAttributes().put(this.csrfTokenAttributeName, csrfToken))
.flatMap( t -> chain.filter(exchange))
.then();
}
private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.loadToken(exchange)
.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
.as(Mono::just); // FIXME eager saving of CsrfToken with .as
}
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>(
Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
@Override
public Mono<MatchResult> matches(ServerWebExchange exchange) {
return Mono.just(exchange.getRequest())
.map(r -> r.getMethod())
.filter(m -> ALLOWED_METHODS.contains(m))
.flatMap(m -> MatchResult.notMatch())
.switchIfEmpty(MatchResult.match());
}
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.springframework.util.Assert;
/**
* A CSRF token that is used to protect against CSRF attacks.
*
* @author Rob Winch
* @since 5.0
*/
@SuppressWarnings("serial")
public final class DefaultCsrfToken implements CsrfToken {
private final String token;
private final String parameterName;
private final String headerName;
/**
* Creates a new instance
* @param headerName the HTTP header name to use
* @param parameterName the HTTP parameter name to use
* @param token the value of the token (i.e. expected value of the HTTP parameter of
* parametername).
*/
public DefaultCsrfToken(String headerName, String parameterName, String token) {
Assert.hasLength(headerName, "headerName cannot be null or empty");
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
Assert.hasLength(token, "token cannot be null or empty");
this.headerName = headerName;
this.parameterName = parameterName;
this.token = token;
}
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
*/
public String getHeaderName() {
return this.headerName;
}
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
*/
public String getParameterName() {
return this.parameterName;
}
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.csrf.CsrfToken#getToken()
*/
public String getToken() {
return this.token;
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
* An API to allow changing the method in which the expected {@link CsrfToken} is
* associated to the {@link ServerWebExchange}. For example, it may be stored in
* {@link org.springframework.web.server.WebSession}.
*
* @see WebSessionServerCsrfTokenRepository
*
* @author Rob Winch
* @since 5.0
*
*/
public interface ServerCsrfTokenRepository {
/**
* Generates a {@link CsrfToken}
*
* @param exchange the {@link ServerWebExchange} to use
* @return the {@link CsrfToken} that was generated. Cannot be null.
*/
Mono<CsrfToken> generateToken(ServerWebExchange exchange);
/**
* Saves the {@link CsrfToken} using the {@link ServerWebExchange}. If the
* {@link CsrfToken} is null, it is the same as deleting it.
*
* @param exchange the {@link ServerWebExchange} to use
* @param token the {@link CsrfToken} to save or null to delete
*/
Mono<Void> saveToken(ServerWebExchange exchange, CsrfToken token);
/**
* Loads the expected {@link CsrfToken} from the {@link ServerWebExchange}
*
* @param exchange the {@link ServerWebExchange} to use
* @return the {@link CsrfToken} or null if none exists
*/
Mono<CsrfToken> loadToken(ServerWebExchange exchange);
}

View File

@ -0,0 +1,122 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import java.util.Map;
import java.util.UUID;
/**
* A {@link ServerCsrfTokenRepository} that stores the {@link CsrfToken} in the
* {@link HttpSession}.
*
* @author Rob Winch
* @since 5.0
*/
public class WebSessionServerCsrfTokenRepository
implements ServerCsrfTokenRepository {
private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
private static final String DEFAULT_CSRF_HEADER_NAME = "X-CSRF-TOKEN";
private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME = WebSessionServerCsrfTokenRepository.class
.getName().concat(".CSRF_TOKEN");
private String parameterName = DEFAULT_CSRF_PARAMETER_NAME;
private String headerName = DEFAULT_CSRF_HEADER_NAME;
private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;
@Override
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return Mono.defer(() -> Mono.just(createCsrfToken()))
.flatMap(token -> save(exchange, token));
}
@Override
public Mono<Void> saveToken(ServerWebExchange exchange, CsrfToken token) {
return save(exchange, token)
.then();
}
private Mono<CsrfToken> save(ServerWebExchange exchange, CsrfToken token) {
return exchange.getSession()
.map(WebSession::getAttributes)
.flatMap( attrs -> save(attrs, token));
}
private Mono<CsrfToken> save(Map<String,Object> attributes, CsrfToken token) {
if(token == null) {
attributes.remove(this.sessionAttributeName);
} else {
attributes.put(this.sessionAttributeName, token);
}
return Mono.justOrEmpty(token);
}
@Override
public Mono<CsrfToken> loadToken(ServerWebExchange exchange) {
return exchange.getSession()
.filter( s -> s.getAttributes().containsKey(this.sessionAttributeName))
.map(s -> s.getAttribute(this.sessionAttributeName));
}
/**
* Sets the {@link HttpServletRequest} parameter name that the {@link CsrfToken} is
* expected to appear on
* @param parameterName the new parameter name to use
*/
public void setParameterName(String parameterName) {
Assert.hasLength(parameterName, "parameterName cannot be null or empty");
this.parameterName = parameterName;
}
/**
* Sets the header name that the {@link CsrfToken} is expected to appear on and the
* header that the response will contain the {@link CsrfToken}.
*
* @param headerName the new header name to use
*/
public void setHeaderName(String headerName) {
Assert.hasLength(headerName, "headerName cannot be null or empty");
this.headerName = headerName;
}
/**
* Sets the {@link HttpSession} attribute name that the {@link CsrfToken} is stored in
* @param sessionAttributeName the new attribute name to use
*/
public void setSessionAttributeName(String sessionAttributeName) {
Assert.hasLength(sessionAttributeName,
"sessionAttributename cannot be null or empty");
this.sessionAttributeName = sessionAttributeName;
}
private CsrfToken createCsrfToken() {
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
}
private String createNewToken() {
return UUID.randomUUID().toString();
}
}

View File

@ -23,6 +23,7 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.MultiValueMap;
@ -50,21 +51,31 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
}
private Mono<Void> render(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
boolean isError = queryParams.containsKey("error");
boolean isLogoutSuccess = queryParams.containsKey("logout");
ServerHttpResponse result = exchange.getResponse();
result.setStatusCode(HttpStatus.FOUND);
result.setStatusCode(HttpStatus.OK);
result.getHeaders().setContentType(MediaType.TEXT_HTML);
byte[] bytes = createPage(isError, isLogoutSuccess);
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
DataBuffer buffer = bufferFactory.wrap(bytes);
return result.writeWith(Mono.just(buffer))
.doOnError( error -> DataBufferUtils.release(buffer));
return result.writeWith(createBuffer(exchange));
// .doOnError( error -> DataBufferUtils.release(buffer));
}
private static byte[] createPage(boolean isError, boolean isLogoutSuccess) {
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
return token
.map(LoginPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {
boolean isError = queryParams.containsKey("error");
boolean isLogoutSuccess = queryParams.containsKey("logout");
byte[] bytes = createPage(isError, isLogoutSuccess, csrfTokenHtmlInput);
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
return bufferFactory.wrap(bytes);
});
}
private static byte[] createPage(boolean isError, boolean isLogoutSuccess, String csrfTokenHtmlInput) {
String page = "<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ " <head>\n"
@ -90,6 +101,7 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
+ " <input type=\"password\" id=\"password\" name=\"password\" class=\"form-control\" placeholder=\"Password\" required>\n"
+ " </p>\n"
+ csrfTokenHtmlInput
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
+ " </form>\n"
+ " </div>\n"
@ -99,6 +111,10 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
return page.getBytes(Charset.defaultCharset());
}
private static String csrfToken(CsrfToken token) {
return " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n";
}
private static String createError(boolean isError) {
return isError ? "<div class=\"alert alert-danger\" role=\"alert\">Invalid credentials</div>" : "";
}

View File

@ -0,0 +1,185 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
public class CsrfWebFilterTests {
@Mock
private WebFilterChain chain;
@Mock
private ServerCsrfTokenRepository repository;
private CsrfToken token = new DefaultCsrfToken("csrf", "CSRF", "a");
private CsrfWebFilter csrfFilter = new CsrfWebFilter();
private MockServerWebExchange get = MockServerWebExchange.from(
MockServerHttpRequest.get("/"));
private MockServerWebExchange post = MockServerWebExchange.from(
MockServerHttpRequest.post("/"));
@Test
public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
when(this.chain.filter(this.get)).thenReturn(chainResult.mono());
Mono<Void> result = this.csrfFilter.filter(this.get, this.chain);
StepVerifier.create(result)
.verifyComplete();
Mono<Boolean> isSessionStarted = this.get.getSession()
.map(WebSession::isStarted);
StepVerifier.create(isSessionStarted)
.expectNext(false)
.verifyComplete();
chainResult.assertWasSubscribed();
}
@Test
public void filterWhenPostAndNoTokenThenCsrfException() {
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
}
@Test
public void filterWhenPostAndEstablishedCsrfTokenAndRequestMissingTokenThenCsrfException() {
this.csrfFilter.setServerCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
}
@Test
public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamInvalidTokenThenCsrfException() {
this.csrfFilter.setServerCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
}
@Test
public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenContinues() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
when(this.chain.filter(any())).thenReturn(chainResult.mono());
this.csrfFilter.setServerCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(this.token.getParameterName() + "="+this.token.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
chainResult.assertWasSubscribed();
}
@Test
public void filterWhenPostAndEstablishedCsrfTokenAndHeaderInvalidTokenThenCsrfException() {
this.csrfFilter.setServerCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
}
@Test
public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinues() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
when(this.chain.filter(any())).thenReturn(chainResult.mono());
this.csrfFilter.setServerCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result)
.verifyComplete();
chainResult.assertWasSubscribed();
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.csrf;
import org.junit.Test;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.util.Map;
import static org.assertj.core.api.Assertions.*;
/**
* @author Rob Winch
* @since 5.0
*/
public class WebSessionServerCsrfTokenRepositoryTests {
private WebSessionServerCsrfTokenRepository repository = new WebSessionServerCsrfTokenRepository();
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
@Test
public void generateTokenWhenNoSubscriptionThenNoSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
Mono<Boolean> isSessionStarted = this.exchange.getSession()
.map(WebSession::isStarted);
StepVerifier.create(isSessionStarted)
.expectNext(false)
.verifyComplete();
}
@Test
public void generateTokenWhenSubscriptionThenAddsToSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
StepVerifier.create(result)
.consumeNextWith( t -> assertThat(t).isNotNull())
.verifyComplete();
WebSession session = this.exchange.getSession().block();
Map<String, Object> attributes = session.getAttributes();
assertThat(session.isStarted()).isTrue();
assertThat(attributes).hasSize(1);
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
}
@Test
public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
CsrfToken token = new DefaultCsrfToken("h","p", "t");
String attrName = "ATTR";
this.repository.setSessionAttributeName(attrName);
Mono<Void> result = this.repository.saveToken(this.exchange, token);
StepVerifier.create(result)
.verifyComplete();
WebSession session = this.exchange.getSession().block();
assertThat(session.isStarted()).isTrue();
assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
}
@Test
public void saveTokenWhenNullThenDeletes() {
CsrfToken token = new DefaultCsrfToken("h","p", "t");
this.repository.saveToken(this.exchange, token).block();
Mono<Void> result = this.repository.saveToken(this.exchange, null);
StepVerifier.create(result)
.verifyComplete();
WebSession session = this.exchange.getSession().block();
assertThat(session.getAttributes()).isEmpty();
}
@Test
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
CsrfToken generate = this.repository.generateToken(this.exchange).block();
CsrfToken load = this.repository.loadToken(this.exchange).block();
assertThat(load).isEqualTo(generate);
this.repository.saveToken(this.exchange, null).block();
WebSession session = this.exchange.getSession().block();
assertThat(session.getAttributes()).isEmpty();
load = this.repository.loadToken(this.exchange).block();
assertThat(load).isNull();
}
}