WebSessionServerCsrfTokenRepository saves on getToken

Fixes gh-4801
This commit is contained in:
Rob Winch 2017-11-07 21:59:47 -06:00
parent 776364d403
commit 7622826b69
7 changed files with 123 additions and 40 deletions

View File

@ -316,9 +316,9 @@ public class FormLoginTests {
public static class CustomLoginPageController { public static class CustomLoginPageController {
@ResponseBody @ResponseBody
@GetMapping("/login") @GetMapping("/login")
public Mono<String> login(ServerWebExchange exchange) { public String login(ServerWebExchange exchange) {
Mono<CsrfToken> token = exchange.getAttribute(CsrfToken.class.getName()); CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
return token.map(t -> return
"<!DOCTYPE html>\n" "<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n" + "<html lang=\"en\">\n"
+ " <head>\n" + " <head>\n"
@ -340,12 +340,12 @@ public class FormLoginTests {
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n" + " <label for=\"password\" class=\"sr-only\">Password</label>\n"
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n" + " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
+ " </p>\n" + " </p>\n"
+ " <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n" + " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
+ " <button type=\"submit\">Sign in</button>\n" + " <button type=\"submit\">Sign in</button>\n"
+ " </form>\n" + " </form>\n"
+ " </div>\n" + " </div>\n"
+ " </body>\n" + " </body>\n"
+ "</html>"); + "</html>";
} }
} }

View File

@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter {
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
return csrfToken(exchange) return csrfToken(exchange)
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken)) .doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
.flatMap( t -> chain.filter(exchange)) .flatMap( t -> chain.filter(exchange))
.then(); .then();
} }
private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) { private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.loadToken(exchange) return this.serverCsrfTokenRepository.loadToken(exchange)
.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange)) .switchIfEmpty(generateToken(exchange));
.as(Mono::just); // FIXME eager saving of CsrfToken with .as }
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return this.serverCsrfTokenRepository.generateToken(exchange)
.flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token));
} }
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {

View File

@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken {
public String getToken() { public String getToken() {
return this.token; return this.token;
} }
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || !(o instanceof CsrfToken))
return false;
CsrfToken that = (CsrfToken) o;
if (!getToken().equals(that.getToken()))
return false;
if (!getParameterName().equals(that.getParameterName()))
return false;
return getHeaderName().equals(that.getHeaderName());
}
@Override
public int hashCode() {
int result = getToken().hashCode();
result = 31 * result + getParameterName().hashCode();
result = 31 * result + getHeaderName().hashCode();
return result;
}
} }

View File

@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository
@Override @Override
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) { public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return Mono.defer(() -> Mono.just(createCsrfToken())) return exchange.getSession()
.flatMap(token -> saveToken(exchange, token)); .map(WebSession::getAttributes)
.map(this::createCsrfToken);
} }
@Override @Override
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) { public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
if(token != null) {
return Mono.just(token);
}
return exchange.getSession() return exchange.getSession()
.map(WebSession::getAttributes) .map(WebSession::getAttributes)
.flatMap( attrs -> save(attrs, token)); .flatMap( attrs -> save(attrs, token));
@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository
this.sessionAttributeName = sessionAttributeName; this.sessionAttributeName = sessionAttributeName;
} }
private CsrfToken createCsrfToken(Map<String,Object> attributes) {
return new LazyCsrfToken(attributes, createCsrfToken());
}
private CsrfToken createCsrfToken() { private CsrfToken createCsrfToken() {
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken()); return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
} }
@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository
private String createNewToken() { private String createNewToken() {
return UUID.randomUUID().toString(); return UUID.randomUUID().toString();
} }
private class LazyCsrfToken implements CsrfToken {
private final Map<String,Object> attributes;
private final CsrfToken delegate;
private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
this.attributes = attributes;
this.delegate = delegate;
}
@Override
public String getHeaderName() {
return this.delegate.getHeaderName();
}
@Override
public String getParameterName() {
return this.delegate.getParameterName();
}
@Override
public String getToken() {
putToken(this.attributes, this.delegate);
return this.delegate.getToken();
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || !(o instanceof CsrfToken))
return false;
CsrfToken that = (CsrfToken) o;
if (!getToken().equals(that.getToken()))
return false;
if (!getParameterName().equals(that.getParameterName()))
return false;
return getHeaderName().equals(that.getHeaderName());
}
@Override
public int hashCode() {
int result = getToken().hashCode();
result = 31 * result + getParameterName().hashCode();
result = 31 * result + getHeaderName().hashCode();
return result;
}
@Override
public String toString() {
return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
}
}
} }

View File

@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) { private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest() MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams(); .getQueryParams();
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes() CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty()); return Mono.justOrEmpty(token)
return token
.map(LoginPageGeneratingWebFilter::csrfToken) .map(LoginPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("") .defaultIfEmpty("")
.map(csrfTokenHtmlInput -> { .map(csrfTokenHtmlInput -> {

View File

@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
} }
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) { private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes() CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty()); return Mono.justOrEmpty(token)
return token
.map(LogoutPageGeneratingWebFilter::csrfToken) .map(LogoutPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("") .defaultIfEmpty("")
.map(csrfTokenHtmlInput -> { .map(csrfTokenHtmlInput -> {

View File

@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
@Test @Test
public void generateTokenWhenNoSubscriptionThenNoSession() { public void generateTokenThenNoSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange); Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
Mono<Boolean> isSessionStarted = this.exchange.getSession() Mono<Boolean> isSessionStarted = this.exchange.getSession()
@ -49,12 +49,21 @@ public class WebSessionServerCsrfTokenRepositoryTests {
} }
@Test @Test
public void generateTokenWhenSubscriptionThenAddsToSession() { public void generateTokenWhenSubscriptionThenNoSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange); Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
StepVerifier.create(result) Mono<Boolean> isSessionStarted = this.exchange.getSession()
.consumeNextWith( t -> assertThat(t).isNotNull()) .map(WebSession::isStarted);
StepVerifier.create(isSessionStarted)
.expectNext(false)
.verifyComplete(); .verifyComplete();
}
@Test
public void generateTokenWhenGetTokenThenAddsToSession() {
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
result.block().getToken();
WebSession session = this.exchange.getSession().block(); WebSession session = this.exchange.getSession().block();
Map<String, Object> attributes = session.getAttributes(); Map<String, Object> attributes = session.getAttributes();
@ -62,30 +71,12 @@ public class WebSessionServerCsrfTokenRepositoryTests {
assertThat(session.isStarted()).isTrue(); assertThat(session.isStarted()).isTrue();
assertThat(attributes).hasSize(1); assertThat(attributes).hasSize(1);
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class); assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
}
@Test
public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
CsrfToken token = new DefaultCsrfToken("h","p", "t");
String attrName = "ATTR";
this.repository.setSessionAttributeName(attrName);
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, token);
StepVerifier.create(result)
.consumeNextWith(n -> assertThat(n).isEqualTo(token))
.verifyComplete();
WebSession session = this.exchange.getSession().block();
assertThat(session.isStarted()).isTrue();
assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
} }
@Test @Test
public void saveTokenWhenNullThenDeletes() { public void saveTokenWhenNullThenDeletes() {
CsrfToken token = new DefaultCsrfToken("h","p", "t"); CsrfToken token = this.repository.generateToken(this.exchange).block();
this.repository.saveToken(this.exchange, token).block(); token.getToken();
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null); Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
StepVerifier.create(result) StepVerifier.create(result)
@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
@Test @Test
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() { public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
CsrfToken generate = this.repository.generateToken(this.exchange).block(); CsrfToken generate = this.repository.generateToken(this.exchange).block();
generate.getToken();
CsrfToken load = this.repository.loadToken(this.exchange).block(); CsrfToken load = this.repository.loadToken(this.exchange).block();
assertThat(load).isEqualTo(generate); assertThat(load).isEqualTo(generate);