parent
9920cb41d1
commit
aa12748c9b
|
@ -35,6 +35,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
|
|||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
|
||||
import static java.lang.Boolean.TRUE;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* Applies
|
||||
|
@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
|||
*/
|
||||
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
|
||||
|
||||
/**
|
||||
* The attribute name to use when marking a given request as one that should not be filtered.
|
||||
*
|
||||
* To use, set the attribute on your {@link HttpServletRequest}:
|
||||
* <pre>
|
||||
* CsrfFilter.skipRequest(request);
|
||||
* </pre>
|
||||
*/
|
||||
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
|
||||
|
||||
private final Log logger = LogFactory.getLog(getClass());
|
||||
private final CsrfTokenRepository tokenRepository;
|
||||
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
|
||||
|
@ -73,6 +85,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
|||
this.tokenRepository = csrfTokenRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||
return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
|
||||
}
|
||||
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
|
@ -124,6 +141,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
|||
filterChain.doFilter(request, response);
|
||||
}
|
||||
|
||||
public static void skipRequest(HttpServletRequest request) {
|
||||
request.setAttribute(SHOULD_NOT_FILTER, TRUE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies a {@link RequestMatcher} that is used to determine if CSRF protection
|
||||
* should be applied. If the {@link RequestMatcher} returns true for a given request,
|
||||
|
|
|
@ -32,6 +32,8 @@ import org.springframework.web.server.ServerWebExchange;
|
|||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
|
||||
import static java.lang.Boolean.TRUE;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* Applies
|
||||
|
@ -60,6 +62,16 @@ import org.springframework.web.server.WebFilterChain;
|
|||
public class CsrfWebFilter implements WebFilter {
|
||||
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
|
||||
|
||||
/**
|
||||
* The attribute name to use when marking a given request as one that should not be filtered.
|
||||
*
|
||||
* To use, set the attribute on your {@link ServerWebExchange}:
|
||||
* <pre>
|
||||
* CsrfWebFilter.skipExchange(exchange);
|
||||
* </pre>
|
||||
*/
|
||||
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName();
|
||||
|
||||
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
|
||||
|
||||
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
|
||||
|
@ -86,6 +98,10 @@ public class CsrfWebFilter implements WebFilter {
|
|||
|
||||
@Override
|
||||
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
|
||||
return chain.filter(exchange).then(Mono.empty());
|
||||
}
|
||||
|
||||
return this.requireCsrfProtectionMatcher.matches(exchange)
|
||||
.filter( matchResult -> matchResult.isMatch())
|
||||
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
|
||||
|
@ -96,6 +112,10 @@ public class CsrfWebFilter implements WebFilter {
|
|||
.handle(exchange, e));
|
||||
}
|
||||
|
||||
public static void skipExchange(ServerWebExchange exchange) {
|
||||
exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
|
||||
}
|
||||
|
||||
private Mono<Void> validateToken(ServerWebExchange exchange) {
|
||||
return this.csrfTokenRepository.loadToken(exchange)
|
||||
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client"))))
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.junit.runner.RunWith;
|
|||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import org.springframework.mock.web.MockFilterChain;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.web.access.AccessDeniedHandler;
|
||||
|
@ -39,6 +40,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.lenient;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -390,6 +393,22 @@ public class CsrfFilterTests {
|
|||
verifyZeroInteractions(this.filterChain);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenSkipRequestInvokedThenSkips()
|
||||
throws Exception {
|
||||
|
||||
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
||||
CsrfFilter filter = new CsrfFilter(repository);
|
||||
|
||||
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
||||
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
CsrfFilter.skipRequest(request);
|
||||
filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
|
||||
|
||||
verifyZeroInteractions(repository);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void setRequireCsrfProtectionMatcherNull() {
|
||||
this.filter.setRequireCsrfProtectionMatcher(null);
|
||||
|
|
|
@ -20,19 +20,24 @@ import org.junit.Test;
|
|||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
import org.springframework.web.server.WebSession;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.test.publisher.PublisherProbe;
|
||||
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
import org.springframework.web.server.WebSession;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.mock.web.server.MockServerWebExchange.from;
|
||||
|
||||
/**
|
||||
* @author Rob Winch
|
||||
|
@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
|
|||
|
||||
private CsrfWebFilter csrfFilter = new CsrfWebFilter();
|
||||
|
||||
private MockServerWebExchange get = MockServerWebExchange.from(
|
||||
private MockServerWebExchange get = from(
|
||||
MockServerHttpRequest.get("/"));
|
||||
|
||||
private MockServerWebExchange post = MockServerWebExchange.from(
|
||||
private MockServerWebExchange post = from(
|
||||
MockServerHttpRequest.post("/"));
|
||||
|
||||
@Test
|
||||
|
@ -104,7 +109,7 @@ public class CsrfWebFilterTests {
|
|||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
|
||||
this.post = from(MockServerHttpRequest.post("/")
|
||||
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
|
||||
|
||||
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
|
||||
|
@ -125,7 +130,7 @@ public class CsrfWebFilterTests {
|
|||
.thenReturn(Mono.just(this.token));
|
||||
when(this.repository.generateToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
|
||||
this.post = from(MockServerHttpRequest.post("/")
|
||||
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
|
||||
.body(this.token.getParameterName() + "="+this.token.getToken()));
|
||||
|
||||
|
@ -142,7 +147,7 @@ public class CsrfWebFilterTests {
|
|||
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
||||
when(this.repository.loadToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
|
||||
this.post = from(MockServerHttpRequest.post("/")
|
||||
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
|
||||
|
||||
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
|
||||
|
@ -163,7 +168,7 @@ public class CsrfWebFilterTests {
|
|||
.thenReturn(Mono.just(this.token));
|
||||
when(this.repository.generateToken(any()))
|
||||
.thenReturn(Mono.just(this.token));
|
||||
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
|
||||
this.post = from(MockServerHttpRequest.post("/")
|
||||
.header(this.token.getHeaderName(), this.token.getToken()));
|
||||
|
||||
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
|
||||
|
@ -173,4 +178,19 @@ public class CsrfWebFilterTests {
|
|||
|
||||
chainResult.assertWasSubscribed();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenSkipExchangeInvokedThenSkips() {
|
||||
PublisherProbe<Void> chainResult = PublisherProbe.empty();
|
||||
when(this.chain.filter(any())).thenReturn(chainResult.mono());
|
||||
|
||||
ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
|
||||
this.csrfFilter.setRequireCsrfProtectionMatcher(matcher);
|
||||
|
||||
MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build());
|
||||
CsrfWebFilter.skipExchange(exchange);
|
||||
this.csrfFilter.filter(exchange, this.chain).block();
|
||||
|
||||
verifyZeroInteractions(matcher);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue