WebSessionServerCsrfTokenRepository saves on getToken

Fixes gh-4801
This commit is contained in:
Rob Winch 2017-11-07 21:59:47 -06:00
parent 776364d403
commit 7622826b69
7 changed files with 123 additions and 40 deletions

View File

@ -316,9 +316,9 @@ public class FormLoginTests {
public static class CustomLoginPageController {
@ResponseBody
@GetMapping("/login")
public Mono<String> login(ServerWebExchange exchange) {
Mono<CsrfToken> token = exchange.getAttribute(CsrfToken.class.getName());
return token.map(t ->
public String login(ServerWebExchange exchange) {
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return
"<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ " <head>\n"
@ -340,12 +340,12 @@ public class FormLoginTests {
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
+ " </p>\n"
+ " <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
+ " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
+ " <button type=\"submit\">Sign in</button>\n"
+ " </form>\n"
+ " </div>\n"
+ " </body>\n"
+ "</html>");
+ "</html>";
}
}

View File

@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter {
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return csrfToken(exchange)
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
.flatMap( t -> chain.filter(exchange))
.then();
}
private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.loadToken(exchange)
.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
.as(Mono::just); // FIXME eager saving of CsrfToken with .as
.switchIfEmpty(generateToken(exchange));
}
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.generateToken(exchange)
.flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token));
}
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {

View File

@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken {
public String getToken() {
return this.token;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || !(o instanceof CsrfToken))
return false;
CsrfToken that = (CsrfToken) o;
if (!getToken().equals(that.getToken()))
return false;
if (!getParameterName().equals(that.getParameterName()))
return false;
return getHeaderName().equals(that.getHeaderName());
}
@Override
public int hashCode() {
int result = getToken().hashCode();
result = 31 * result + getParameterName().hashCode();
result = 31 * result + getHeaderName().hashCode();
return result;
}
}

View File

@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository
@Override
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return Mono.defer(() -> Mono.just(createCsrfToken()))
.flatMap(token -> saveToken(exchange, token));
return exchange.getSession()
.map(WebSession::getAttributes)
.map(this::createCsrfToken);
}
@Override
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
if(token != null) {
return Mono.just(token);
}
return exchange.getSession()
.map(WebSession::getAttributes)
.flatMap( attrs -> save(attrs, token));
@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository
this.sessionAttributeName = sessionAttributeName;
}
private CsrfToken createCsrfToken(Map<String,Object> attributes) {
return new LazyCsrfToken(attributes, createCsrfToken());
}
private CsrfToken createCsrfToken() {
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
}
@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository
private String createNewToken() {
return UUID.randomUUID().toString();
}
private class LazyCsrfToken implements CsrfToken {
private final Map<String,Object> attributes;
private final CsrfToken delegate;
private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
this.attributes = attributes;
this.delegate = delegate;
}
@Override
public String getHeaderName() {
return this.delegate.getHeaderName();
}
@Override
public String getParameterName() {
return this.delegate.getParameterName();
}
@Override
public String getToken() {
putToken(this.attributes, this.delegate);
return this.delegate.getToken();
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || !(o instanceof CsrfToken))
return false;
CsrfToken that = (CsrfToken) o;
if (!getToken().equals(that.getToken()))
return false;
if (!getParameterName().equals(that.getParameterName()))
return false;
return getHeaderName().equals(that.getHeaderName());
}
@Override
public int hashCode() {
int result = getToken().hashCode();
result = 31 * result + getParameterName().hashCode();
result = 31 * result + getHeaderName().hashCode();
return result;
}
@Override
public String toString() {
return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
}
}
}

View File

@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
return token
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return Mono.justOrEmpty(token)
.map(LoginPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {

View File

@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
}
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
return token
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return Mono.justOrEmpty(token)
.map(LogoutPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {

View File

@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
@Test
public void generateTokenWhenNoSubscriptionThenNoSession() {
public void generateTokenThenNoSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
Mono<Boolean> isSessionStarted = this.exchange.getSession()
@ -49,12 +49,21 @@ public class WebSessionServerCsrfTokenRepositoryTests {
}
@Test
public void generateTokenWhenSubscriptionThenAddsToSession() {
public void generateTokenWhenSubscriptionThenNoSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
StepVerifier.create(result)
.consumeNextWith( t -> assertThat(t).isNotNull())
Mono<Boolean> isSessionStarted = this.exchange.getSession()
.map(WebSession::isStarted);
StepVerifier.create(isSessionStarted)
.expectNext(false)
.verifyComplete();
}
@Test
public void generateTokenWhenGetTokenThenAddsToSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
result.block().getToken();
WebSession session = this.exchange.getSession().block();
Map<String, Object> attributes = session.getAttributes();
@ -62,30 +71,12 @@ public class WebSessionServerCsrfTokenRepositoryTests {
assertThat(session.isStarted()).isTrue();
assertThat(attributes).hasSize(1);
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
}
@Test
public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
CsrfToken token = new DefaultCsrfToken("h","p", "t");
String attrName = "ATTR";
this.repository.setSessionAttributeName(attrName);
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, token);
StepVerifier.create(result)
.consumeNextWith(n -> assertThat(n).isEqualTo(token))
.verifyComplete();
WebSession session = this.exchange.getSession().block();
assertThat(session.isStarted()).isTrue();
assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
}
@Test
public void saveTokenWhenNullThenDeletes() {
CsrfToken token = new DefaultCsrfToken("h","p", "t");
this.repository.saveToken(this.exchange, token).block();
CsrfToken token = this.repository.generateToken(this.exchange).block();
token.getToken();
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
StepVerifier.create(result)
@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
@Test
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
CsrfToken generate = this.repository.generateToken(this.exchange).block();
generate.getToken();
CsrfToken load = this.repository.loadToken(this.exchange).block();
assertThat(load).isEqualTo(generate);