CsrfWebFilter places Mono<CsrfToken>

Fixes: gh-4855
This commit is contained in:
Rob Winch 2017-11-20 14:16:49 -06:00
parent edccafca84
commit d55db837e1
11 changed files with 73 additions and 114 deletions

View File

@ -33,6 +33,7 @@ import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.assertThat;
@ -314,9 +315,9 @@ public class FormLoginTests {
public static class CustomLoginPageController {
@ResponseBody
@GetMapping("/login")
public String login(ServerWebExchange exchange) {
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return
public Mono<String> login(ServerWebExchange exchange) {
Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
return token.map(t ->
"<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ " <head>\n"
@ -338,12 +339,12 @@ public class FormLoginTests {
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
+ " </p>\n"
+ " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
+ " <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
+ " <button type=\"submit\">Sign in</button>\n"
+ " </form>\n"
+ " </div>\n"
+ " </body>\n"
+ "</html>";
+ "</html>");
}
}
}

View File

@ -26,7 +26,6 @@ import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverB
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.reactive.server.WebTestClient;
@ -126,7 +125,6 @@ public class RequestCacheTests {
@ResponseBody
@GetMapping("/secured")
public String login(ServerWebExchange exchange) {
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return
"<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"

View File

@ -0,0 +1,38 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sample;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME;
/**
* @author Rob Winch
* @since 5.0
*/
@ControllerAdvice
public class CsrfControllerAdvice {
@ModelAttribute
public Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
Mono<CsrfToken> csrfToken = exchange.getAttribute(CsrfToken.class.getName());
return csrfToken.doOnSuccess(token -> exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token));
}
}

View File

@ -30,6 +30,10 @@ import java.util.regex.Pattern;
* @since 5.0
*/
public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor {
/**
* The default request attribute to look for a {@link CsrfToken}.
*/
public static final String DEFAULT_CSRF_ATTR_NAME = "_csrf";
private static final Pattern DISABLE_CSRF_TOKEN_PATTERN = Pattern
.compile("(?i)^(GET|HEAD|TRACE|OPTIONS)$");
@ -62,7 +66,7 @@ public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor
exchange.getAttributes().remove(DISABLE_CSRF_TOKEN_ATTR);
return Collections.emptyMap();
}
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
CsrfToken token = exchange.getAttribute(DEFAULT_CSRF_ATTR_NAME);
if(token == null) {
return Collections.emptyMap();
}

View File

@ -47,12 +47,16 @@ import java.util.Set;
* {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in
* a cookie which can be modified by a client application.
* </p>
* <p>
* The {@code Mono&lt;CsrfToken&gt;} is exposes as a request attribute with the name of
* {@code CsrfToken.class.getName()}. If the token is new it will automatically be saved
* at the time it is subscribed.
* </p>
*
* @author Rob Winch
* @since 5.0
*/
public class CsrfWebFilter implements WebFilter {
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher();
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@ -105,11 +109,11 @@ public class CsrfWebFilter implements WebFilter {
}
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return csrfToken(exchange)
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
.flatMap( t -> chain.filter(exchange))
.then();
return Mono.defer(() ->{
Mono<CsrfToken> csrfToken = csrfToken(exchange);
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
return chain.filter(exchange);
});
}
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {

View File

@ -17,7 +17,6 @@ package org.springframework.security.web.server.csrf;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import javax.servlet.http.HttpServletRequest;
@ -49,20 +48,15 @@ public class WebSessionServerCsrfTokenRepository
@Override
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return exchange.getSession()
.map(WebSession::getAttributes)
.map(this::createCsrfToken);
return Mono.fromCallable(() -> createCsrfToken());
}
@Override
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
if(token != null) {
return Mono.just(token);
}
return exchange.getSession()
.doOnSuccess(session -> putToken(session.getAttributes(), token))
.doOnNext(session -> putToken(session.getAttributes(), token))
.flatMap(session -> session.changeSessionId())
.flatMap(r -> Mono.justOrEmpty(token));
.then(Mono.justOrEmpty(token));
}
private void putToken(Map<String, Object> attributes, CsrfToken token) {
@ -111,11 +105,6 @@ public class WebSessionServerCsrfTokenRepository
this.sessionAttributeName = sessionAttributeName;
}
private CsrfToken createCsrfToken(Map<String, Object> attributes) {
return new LazyCsrfToken(attributes, createCsrfToken());
}
private CsrfToken createCsrfToken() {
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
}
@ -124,58 +113,4 @@ public class WebSessionServerCsrfTokenRepository
return UUID.randomUUID().toString();
}
private class LazyCsrfToken implements CsrfToken {
private final Map<String, Object> attributes;
private final CsrfToken delegate;
private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
this.attributes = attributes;
this.delegate = delegate;
}
@Override
public String getHeaderName() {
return this.delegate.getHeaderName();
}
@Override
public String getParameterName() {
return this.delegate.getParameterName();
}
@Override
public String getToken() {
putToken(this.attributes, this.delegate);
return this.delegate.getToken();
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || !(o instanceof CsrfToken))
return false;
CsrfToken that = (CsrfToken) o;
if (!getToken().equals(that.getToken()))
return false;
if (!getParameterName().equals(that.getParameterName()))
return false;
return getHeaderName().equals(that.getHeaderName());
}
@Override
public int hashCode() {
int result = getToken().hashCode();
result = 31 * result + getParameterName().hashCode();
result = 31 * result + getHeaderName().hashCode();
return result;
}
@Override
public String toString() {
return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
}
}
}

View File

@ -60,8 +60,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return Mono.justOrEmpty(token)
Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
return token
.map(LoginPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {

View File

@ -57,8 +57,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
}
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return Mono.justOrEmpty(token)
Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
return token
.map(LogoutPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {

View File

@ -30,6 +30,7 @@ import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME;
/**
* @author Rob Winch
@ -46,7 +47,7 @@ public class CsrfRequestDataValueProcessorTests {
@Before
public void setup() {
this.expected.put(this.token.getParameterName(), this.token.getToken());
this.exchange.getAttributes().put(CsrfToken.class.getName(), this.token);
this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, this.token);
}
@Test
@ -122,7 +123,7 @@ public class CsrfRequestDataValueProcessorTests {
@Test
public void createGetExtraHiddenFieldsHasCsrfToken() {
CsrfToken token = new DefaultCsrfToken("1", "a", "b");
this.exchange.getAttributes().put(CsrfToken.class.getName(), token);
this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token);
Map<String, String> expected = new HashMap<String, String>();
expected.put(token.getParameterName(), token.getToken());

View File

@ -89,8 +89,6 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@ -106,8 +104,6 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
@ -146,8 +142,6 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));

View File

@ -61,9 +61,10 @@ public class WebSessionServerCsrfTokenRepositoryTests {
}
@Test
public void generateTokenWhenGetTokenThenAddsToSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
result.block().getToken();
public void saveTokenWhenDefaultThenAddsToSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange)
.delayUntil(t-> this.repository.saveToken(this.exchange, t));
result.block();
WebSession session = this.exchange.getSession().block();
Map<String, Object> attributes = session.getAttributes();
@ -76,7 +77,6 @@ public class WebSessionServerCsrfTokenRepositoryTests {
@Test
public void saveTokenWhenNullThenDeletes() {
CsrfToken token = this.repository.generateToken(this.exchange).block();
token.getToken();
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
StepVerifier.create(result)
@ -87,22 +87,6 @@ public class WebSessionServerCsrfTokenRepositoryTests {
assertThat(session.getAttributes()).isEmpty();
}
@Test
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
CsrfToken generate = this.repository.generateToken(this.exchange).block();
generate.getToken();
CsrfToken load = this.repository.loadToken(this.exchange).block();
assertThat(load).isEqualTo(generate);
this.repository.saveToken(this.exchange, null).block();
WebSession session = this.exchange.getSession().block();
assertThat(session.getAttributes()).isEmpty();
load = this.repository.loadToken(this.exchange).block();
assertThat(load).isNull();
}
@Test
public void saveTokenChangeSessionId() {
String originalSessionId = this.exchange.getSession().block().getId();