WebSessionServerCsrfTokenRepository saves on getToken
Fixes gh-4801
This commit is contained in:
parent
776364d403
commit
7622826b69
|
@ -316,9 +316,9 @@ public class FormLoginTests {
|
||||||
public static class CustomLoginPageController {
|
public static class CustomLoginPageController {
|
||||||
@ResponseBody
|
@ResponseBody
|
||||||
@GetMapping("/login")
|
@GetMapping("/login")
|
||||||
public Mono<String> login(ServerWebExchange exchange) {
|
public String login(ServerWebExchange exchange) {
|
||||||
Mono<CsrfToken> token = exchange.getAttribute(CsrfToken.class.getName());
|
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||||
return token.map(t ->
|
return
|
||||||
"<!DOCTYPE html>\n"
|
"<!DOCTYPE html>\n"
|
||||||
+ "<html lang=\"en\">\n"
|
+ "<html lang=\"en\">\n"
|
||||||
+ " <head>\n"
|
+ " <head>\n"
|
||||||
|
@ -340,12 +340,12 @@ public class FormLoginTests {
|
||||||
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
|
+ " <label for=\"password\" class=\"sr-only\">Password</label>\n"
|
||||||
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
|
+ " <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
|
||||||
+ " </p>\n"
|
+ " </p>\n"
|
||||||
+ " <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
|
+ " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
|
||||||
+ " <button type=\"submit\">Sign in</button>\n"
|
+ " <button type=\"submit\">Sign in</button>\n"
|
||||||
+ " </form>\n"
|
+ " </form>\n"
|
||||||
+ " </div>\n"
|
+ " </div>\n"
|
||||||
+ " </body>\n"
|
+ " </body>\n"
|
||||||
+ "</html>");
|
+ "</html>";
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter {
|
||||||
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
||||||
return csrfToken(exchange)
|
return csrfToken(exchange)
|
||||||
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
|
.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
|
||||||
|
.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
|
||||||
.flatMap( t -> chain.filter(exchange))
|
.flatMap( t -> chain.filter(exchange))
|
||||||
.then();
|
.then();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
|
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
|
||||||
return this.serverCsrfTokenRepository.loadToken(exchange)
|
return this.serverCsrfTokenRepository.loadToken(exchange)
|
||||||
.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
|
.switchIfEmpty(generateToken(exchange));
|
||||||
.as(Mono::just); // FIXME eager saving of CsrfToken with .as
|
}
|
||||||
|
|
||||||
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
||||||
|
return this.serverCsrfTokenRepository.generateToken(exchange)
|
||||||
|
.flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
|
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
|
||||||
|
|
|
@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken {
|
||||||
public String getToken() {
|
public String getToken() {
|
||||||
return this.token;
|
return this.token;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (o == null || !(o instanceof CsrfToken))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
CsrfToken that = (CsrfToken) o;
|
||||||
|
|
||||||
|
if (!getToken().equals(that.getToken()))
|
||||||
|
return false;
|
||||||
|
if (!getParameterName().equals(that.getParameterName()))
|
||||||
|
return false;
|
||||||
|
return getHeaderName().equals(that.getHeaderName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = getToken().hashCode();
|
||||||
|
result = 31 * result + getParameterName().hashCode();
|
||||||
|
result = 31 * result + getHeaderName().hashCode();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
||||||
return Mono.defer(() -> Mono.just(createCsrfToken()))
|
return exchange.getSession()
|
||||||
.flatMap(token -> saveToken(exchange, token));
|
.map(WebSession::getAttributes)
|
||||||
|
.map(this::createCsrfToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
|
public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
|
||||||
|
if(token != null) {
|
||||||
|
return Mono.just(token);
|
||||||
|
}
|
||||||
return exchange.getSession()
|
return exchange.getSession()
|
||||||
.map(WebSession::getAttributes)
|
.map(WebSession::getAttributes)
|
||||||
.flatMap( attrs -> save(attrs, token));
|
.flatMap( attrs -> save(attrs, token));
|
||||||
|
@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository
|
||||||
this.sessionAttributeName = sessionAttributeName;
|
this.sessionAttributeName = sessionAttributeName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private CsrfToken createCsrfToken(Map<String,Object> attributes) {
|
||||||
|
return new LazyCsrfToken(attributes, createCsrfToken());
|
||||||
|
}
|
||||||
|
|
||||||
private CsrfToken createCsrfToken() {
|
private CsrfToken createCsrfToken() {
|
||||||
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
|
return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
|
||||||
}
|
}
|
||||||
|
@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository
|
||||||
private String createNewToken() {
|
private String createNewToken() {
|
||||||
return UUID.randomUUID().toString();
|
return UUID.randomUUID().toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class LazyCsrfToken implements CsrfToken {
|
||||||
|
private final Map<String,Object> attributes;
|
||||||
|
private final CsrfToken delegate;
|
||||||
|
|
||||||
|
private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
|
||||||
|
this.attributes = attributes;
|
||||||
|
this.delegate = delegate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getHeaderName() {
|
||||||
|
return this.delegate.getHeaderName();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getParameterName() {
|
||||||
|
return this.delegate.getParameterName();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getToken() {
|
||||||
|
putToken(this.attributes, this.delegate);
|
||||||
|
return this.delegate.getToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (o == null || !(o instanceof CsrfToken))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
CsrfToken that = (CsrfToken) o;
|
||||||
|
|
||||||
|
if (!getToken().equals(that.getToken()))
|
||||||
|
return false;
|
||||||
|
if (!getParameterName().equals(that.getParameterName()))
|
||||||
|
return false;
|
||||||
|
return getHeaderName().equals(that.getHeaderName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = getToken().hashCode();
|
||||||
|
result = 31 * result + getParameterName().hashCode();
|
||||||
|
result = 31 * result + getHeaderName().hashCode();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
|
||||||
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
||||||
MultiValueMap<String, String> queryParams = exchange.getRequest()
|
MultiValueMap<String, String> queryParams = exchange.getRequest()
|
||||||
.getQueryParams();
|
.getQueryParams();
|
||||||
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
|
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||||
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
|
return Mono.justOrEmpty(token)
|
||||||
return token
|
|
||||||
.map(LoginPageGeneratingWebFilter::csrfToken)
|
.map(LoginPageGeneratingWebFilter::csrfToken)
|
||||||
.defaultIfEmpty("")
|
.defaultIfEmpty("")
|
||||||
.map(csrfTokenHtmlInput -> {
|
.map(csrfTokenHtmlInput -> {
|
||||||
|
|
|
@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
|
||||||
Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
|
CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
|
||||||
.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
|
return Mono.justOrEmpty(token)
|
||||||
return token
|
|
||||||
.map(LogoutPageGeneratingWebFilter::csrfToken)
|
.map(LogoutPageGeneratingWebFilter::csrfToken)
|
||||||
.defaultIfEmpty("")
|
.defaultIfEmpty("")
|
||||||
.map(csrfTokenHtmlInput -> {
|
.map(csrfTokenHtmlInput -> {
|
||||||
|
|
|
@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
||||||
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
|
private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void generateTokenWhenNoSubscriptionThenNoSession() {
|
public void generateTokenThenNoSession() {
|
||||||
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||||
|
|
||||||
Mono<Boolean> isSessionStarted = this.exchange.getSession()
|
Mono<Boolean> isSessionStarted = this.exchange.getSession()
|
||||||
|
@ -49,12 +49,21 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void generateTokenWhenSubscriptionThenAddsToSession() {
|
public void generateTokenWhenSubscriptionThenNoSession() {
|
||||||
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||||
|
|
||||||
StepVerifier.create(result)
|
Mono<Boolean> isSessionStarted = this.exchange.getSession()
|
||||||
.consumeNextWith( t -> assertThat(t).isNotNull())
|
.map(WebSession::isStarted);
|
||||||
|
|
||||||
|
StepVerifier.create(isSessionStarted)
|
||||||
|
.expectNext(false)
|
||||||
.verifyComplete();
|
.verifyComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void generateTokenWhenGetTokenThenAddsToSession() {
|
||||||
|
Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
|
||||||
|
result.block().getToken();
|
||||||
|
|
||||||
WebSession session = this.exchange.getSession().block();
|
WebSession session = this.exchange.getSession().block();
|
||||||
Map<String, Object> attributes = session.getAttributes();
|
Map<String, Object> attributes = session.getAttributes();
|
||||||
|
@ -62,30 +71,12 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
||||||
assertThat(session.isStarted()).isTrue();
|
assertThat(session.isStarted()).isTrue();
|
||||||
assertThat(attributes).hasSize(1);
|
assertThat(attributes).hasSize(1);
|
||||||
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
|
assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
|
|
||||||
CsrfToken token = new DefaultCsrfToken("h","p", "t");
|
|
||||||
String attrName = "ATTR";
|
|
||||||
this.repository.setSessionAttributeName(attrName);
|
|
||||||
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, token);
|
|
||||||
|
|
||||||
StepVerifier.create(result)
|
|
||||||
.consumeNextWith(n -> assertThat(n).isEqualTo(token))
|
|
||||||
.verifyComplete();
|
|
||||||
|
|
||||||
WebSession session = this.exchange.getSession().block();
|
|
||||||
|
|
||||||
assertThat(session.isStarted()).isTrue();
|
|
||||||
assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveTokenWhenNullThenDeletes() {
|
public void saveTokenWhenNullThenDeletes() {
|
||||||
CsrfToken token = new DefaultCsrfToken("h","p", "t");
|
CsrfToken token = this.repository.generateToken(this.exchange).block();
|
||||||
this.repository.saveToken(this.exchange, token).block();
|
token.getToken();
|
||||||
|
|
||||||
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
|
Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
|
||||||
StepVerifier.create(result)
|
StepVerifier.create(result)
|
||||||
|
@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
|
||||||
@Test
|
@Test
|
||||||
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
|
public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
|
||||||
CsrfToken generate = this.repository.generateToken(this.exchange).block();
|
CsrfToken generate = this.repository.generateToken(this.exchange).block();
|
||||||
|
generate.getToken();
|
||||||
|
|
||||||
CsrfToken load = this.repository.loadToken(this.exchange).block();
|
CsrfToken load = this.repository.loadToken(this.exchange).block();
|
||||||
assertThat(load).isEqualTo(generate);
|
assertThat(load).isEqualTo(generate);
|
||||||
|
|
Loading…
Reference in New Issue