WebSessionServerCsrfTokenRepository saves on getToken
Fixes gh-4801
This commit is contained in:
parent
776364d403
commit
7622826b69
|
@ -316,9 +316,9 @@ public class FormLoginTests {
|
|||
public static class CustomLoginPageController {
|
||||
@ResponseBody
|
||||
@GetMapping("/login")
|
||||
public Mono<String> login(ServerWebExchange exchange) {
|
||||
Mono<CsrfToken> token = exchange.getAttribute(CsrfToken.class.getName());
|
||||
return token.map(t ->
|
||||
public String login(ServerWebExchange exchange) {
|
||||
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||
return
|
||||
"<!DOCTYPE html>\n"
|
||||
+ "<html lang=\"en\">\n"
|
||||
+ " <head>\n"
|
||||
|
@ -340,12 +340,12 @@ public class FormLoginTests {
|
|||
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
|
||||
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
|
||||
+ " </p>\n"
|
||||
+ " <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
|
||||
+ " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
|
||||
+ " <button type=\"submit\">Sign in</button>\n"
|
||||
+ " </form>\n"
|
||||
+ " </div>\n"
|
||||
+ " </body>\n"
|
||||
+ "</html>");
|
||||
+ "</html>";
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter {
|
|||
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
return csrfToken(exchange)
|
||||
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
|
||||
.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
|
||||
.flatMap( t -> chain.filter(exchange))
|
||||
.then();
|
||||
}
|
||||
|
||||
private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
|
||||
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
|
||||
return this.serverCsrfTokenRepository.loadToken(exchange)
|
||||
.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
|
||||
.as(Mono::just); // FIXME eager saving of CsrfToken with .as
|
||||
.switchIfEmpty(generateToken(exchange));
|
||||
}
|
||||
|
||||
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
||||
return this.serverCsrfTokenRepository.generateToken(exchange)
|
||||
.flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token));
|
||||
}
|
||||
|
||||
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
|
||||
|
|
|
@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken {
|
|||
public String getToken() {
|
||||
return this.token;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
return true;
|
||||
if (o == null || !(o instanceof CsrfToken))
|
||||
return false;
|
||||
|
||||
CsrfToken that = (CsrfToken) o;
|
||||
|
||||
if (!getToken().equals(that.getToken()))
|
||||
return false;
|
||||
if (!getParameterName().equals(that.getParameterName()))
|
||||
return false;
|
||||
return getHeaderName().equals(that.getHeaderName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = getToken().hashCode();
|
||||
result = 31 * result + getParameterName().hashCode();
|
||||
result = 31 * result + getHeaderName().hashCode();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository
|
|||
|
||||
@Override
|
||||
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
||||
return Mono.defer(() -> Mono.just(createCsrfToken()))
|
||||
.flatMap(token -> saveToken(exchange, token));
|
||||
return exchange.getSession()
|
||||
.map(WebSession::getAttributes)
|
||||
.map(this::createCsrfToken);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
|
||||
if(token != null) {
|
||||
return Mono.just(token);
|
||||
}
|
||||
return exchange.getSession()
|
||||
.map(WebSession::getAttributes)
|
||||
.flatMap( attrs -> save(attrs, token));
|
||||
|
@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository
|
|||
this.sessionAttributeName = sessionAttributeName;
|
||||
}
|
||||
|
||||
|
||||
private CsrfToken createCsrfToken(Map<String,Object> attributes) {
|
||||
return new LazyCsrfToken(attributes, createCsrfToken());
|
||||
}
|
||||
|
||||
private CsrfToken createCsrfToken() {
|
||||
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
|
||||
}
|
||||
|
@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository
|
|||
private String createNewToken() {
|
||||
return UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
private class LazyCsrfToken implements CsrfToken {
|
||||
private final Map<String,Object> attributes;
|
||||
private final CsrfToken delegate;
|
||||
|
||||
private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
|
||||
this.attributes = attributes;
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getHeaderName() {
|
||||
return this.delegate.getHeaderName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getParameterName() {
|
||||
return this.delegate.getParameterName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getToken() {
|
||||
putToken(this.attributes, this.delegate);
|
||||
return this.delegate.getToken();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
return true;
|
||||
if (o == null || !(o instanceof CsrfToken))
|
||||
return false;
|
||||
|
||||
CsrfToken that = (CsrfToken) o;
|
||||
|
||||
if (!getToken().equals(that.getToken()))
|
||||
return false;
|
||||
if (!getParameterName().equals(that.getParameterName()))
|
||||
return false;
|
||||
return getHeaderName().equals(that.getHeaderName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = getToken().hashCode();
|
||||
result = 31 * result + getParameterName().hashCode();
|
||||
result = 31 * result + getHeaderName().hashCode();
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
|
|||
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
||||
MultiValueMap<String, String> queryParams = exchange.getRequest()
|
||||
.getQueryParams();
|
||||
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
|
||||
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
|
||||
return token
|
||||
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||
return Mono.justOrEmpty(token)
|
||||
.map(LoginPageGeneratingWebFilter::csrfToken)
|
||||
.defaultIfEmpty("")
|
||||
.map(csrfTokenHtmlInput -> {
|
||||
|
|
|
@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
|
|||
}
|
||||
|
||||
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
||||
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
|
||||
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
|
||||
return token
|
||||
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||
return Mono.justOrEmpty(token)
|
||||
.map(LogoutPageGeneratingWebFilter::csrfToken)
|
||||
.defaultIfEmpty("")
|
||||
.map(csrfTokenHtmlInput -> {
|
||||
|
|
|
@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
|||
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
|
||||
|
||||
@Test
|
||||
public void generateTokenWhenNoSubscriptionThenNoSession() {
|
||||
public void generateTokenThenNoSession() {
|
||||
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||
|
||||
Mono<Boolean> isSessionStarted = this.exchange.getSession()
|
||||
|
@ -49,12 +49,21 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void generateTokenWhenSubscriptionThenAddsToSession() {
|
||||
public void generateTokenWhenSubscriptionThenNoSession() {
|
||||
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||
|
||||
StepVerifier.create(result)
|
||||
.consumeNextWith( t -> assertThat(t).isNotNull())
|
||||
Mono<Boolean> isSessionStarted = this.exchange.getSession()
|
||||
.map(WebSession::isStarted);
|
||||
|
||||
StepVerifier.create(isSessionStarted)
|
||||
.expectNext(false)
|
||||
.verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void generateTokenWhenGetTokenThenAddsToSession() {
|
||||
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||
result.block().getToken();
|
||||
|
||||
WebSession session = this.exchange.getSession().block();
|
||||
Map<String, Object> attributes = session.getAttributes();
|
||||
|
@ -62,30 +71,12 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
|||
assertThat(session.isStarted()).isTrue();
|
||||
assertThat(attributes).hasSize(1);
|
||||
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
|
||||
CsrfToken token = new DefaultCsrfToken("h","p", "t");
|
||||
String attrName = "ATTR";
|
||||
this.repository.setSessionAttributeName(attrName);
|
||||
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, token);
|
||||
|
||||
StepVerifier.create(result)
|
||||
.consumeNextWith(n -> assertThat(n).isEqualTo(token))
|
||||
.verifyComplete();
|
||||
|
||||
WebSession session = this.exchange.getSession().block();
|
||||
|
||||
assertThat(session.isStarted()).isTrue();
|
||||
assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveTokenWhenNullThenDeletes() {
|
||||
CsrfToken token = new DefaultCsrfToken("h","p", "t");
|
||||
this.repository.saveToken(this.exchange, token).block();
|
||||
CsrfToken token = this.repository.generateToken(this.exchange).block();
|
||||
token.getToken();
|
||||
|
||||
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
|
||||
StepVerifier.create(result)
|
||||
|
@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
|||
@Test
|
||||
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
|
||||
CsrfToken generate = this.repository.generateToken(this.exchange).block();
|
||||
generate.getToken();
|
||||
|
||||
CsrfToken load = this.repository.loadToken(this.exchange).block();
|
||||
assertThat(load).isEqualTo(generate);
|
||||
|
|
Loading…
Reference in New Issue