Automatically add CsrfServerLogoutHandler if csrf enabled

The configuration DSL should automatically add CsrfServerLogoutHandler if csrf is enabled

Fixes gh-5337
This commit is contained in:
Eric Deandrea 2018-05-24 14:16:04 -04:00 committed by Rob Winch
parent 725b3b5482
commit b060ec050a
5 changed files with 205 additions and 21 deletions

View File

@ -27,6 +27,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import reactor.core.publisher.Mono;
@ -92,7 +93,9 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.authentication.ServerFormLoginAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerHttpBasicAuthenticationConverter;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler;
import org.springframework.security.web.server.authorization.AuthorizationContext;
@ -106,8 +109,10 @@ import org.springframework.security.web.server.context.ReactorContextWebFilter;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
import org.springframework.security.web.server.header.ContentSecurityPolicyServerHttpHeadersWriter;
@ -1538,6 +1543,7 @@ public class ServerHttpSecurity {
*/
public class CsrfSpec {
private CsrfWebFilter filter = new CsrfWebFilter();
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
private boolean specifiedRequireCsrfProtectionMatcher;
@ -1563,7 +1569,7 @@ public class ServerHttpSecurity {
*/
public CsrfSpec csrfTokenRepository(
ServerCsrfTokenRepository csrfTokenRepository) {
this.filter.setCsrfTokenRepository(csrfTokenRepository);
this.csrfTokenRepository = csrfTokenRepository;
return this;
}
@ -1600,6 +1606,10 @@ public class ServerHttpSecurity {
}
protected void configure(ServerHttpSecurity http) {
Optional.ofNullable(this.csrfTokenRepository).ifPresent(serverCsrfTokenRepository -> {
this.filter.setCsrfTokenRepository(serverCsrfTokenRepository);
http.logout().logoutHandler(new CsrfServerLogoutHandler(serverCsrfTokenRepository));
});
http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF);
}
@ -2332,6 +2342,7 @@ public class ServerHttpSecurity {
*/
public final class LogoutSpec {
private LogoutWebFilter logoutWebFilter = new LogoutWebFilter();
private List<ServerLogoutHandler> logoutHandlers = new ArrayList<>(Arrays.asList(new SecurityContextServerLogoutHandler()));
/**
* Configures the logout handler. Default is {@code SecurityContextServerLogoutHandler}
@ -2339,7 +2350,10 @@ public class ServerHttpSecurity {
* @return the {@link LogoutSpec} to configure
*/
public LogoutSpec logoutHandler(ServerLogoutHandler logoutHandler) {
this.logoutWebFilter.setLogoutHandler(logoutHandler);
if (logoutHandler != null) {
this.logoutHandlers.add(logoutHandler);
}
return this;
}
@ -2387,7 +2401,19 @@ public class ServerHttpSecurity {
return and();
}
private Optional<ServerLogoutHandler> createLogoutHandler() {
if (this.logoutHandlers.isEmpty()) {
return Optional.empty();
}
else if (this.logoutHandlers.size() == 1) {
return Optional.of(this.logoutHandlers.get(0));
}
return Optional.of(new DelegatingServerLogoutHandler(this.logoutHandlers));
}
protected void configure(ServerHttpSecurity http) {
createLogoutHandler().ifPresent(this.logoutWebFilter::setLogoutHandler);
http.addFilterAt(this.logoutWebFilter, SecurityWebFiltersOrder.LOGOUT);
}

View File

@ -16,12 +16,27 @@
package org.springframework.security.config.web.server;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.http.HttpHeaders;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
@ -29,21 +44,23 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;
import org.springframework.web.server.WebFilter;
/**
* @author Rob Winch
@ -55,6 +72,8 @@ public class ServerHttpSecurityTests {
private ServerSecurityContextRepository contextRepository;
@Mock
private ReactiveAuthenticationManager authenticationManager;
@Mock
private ServerCsrfTokenRepository csrfTokenRepository;
private ServerHttpSecurity http;
@ -134,6 +153,51 @@ public class ServerHttpSecurityTests {
.expectBody(String.class).isEqualTo("/foo/bar");
}
@Test
public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
SecurityWebFilterChain securityWebFilterChain = this.http.csrf().disable().build();
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
.isNotPresent();
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler)
.get()
.isExactlyInstanceOf(SecurityContextServerLogoutHandler.class);
}
@Test
public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
SecurityWebFilterChain securityWebFilterChain = this.http.csrf().csrfTokenRepository(this.csrfTokenRepository).and().build();
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
.get()
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
.isEqualTo(this.csrfTokenRepository);
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
assertThat(logoutHandler)
.get()
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler ->
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
.map(ServerLogoutHandler::getClass)
.collect(Collectors.toList()))
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
}
private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
return (Optional<T>) filterChain.getWebFilters()
.filter(Objects::nonNull)
.filter(filter -> filter.getClass().isAssignableFrom(filterClass))
.singleOrEmpty()
.blockOptional();
}
private WebTestClient buildClient() {
WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(
this.http.build());

View File

@ -18,16 +18,17 @@ package org.springframework.security.web.server.authentication.logout;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import reactor.core.publisher.Mono;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
/**
* Delegates to a collection of {@link ServerLogoutHandler} implementations.
*
@ -35,21 +36,24 @@ import reactor.core.publisher.Mono;
* @since 5.1
*/
public class DelegatingServerLogoutHandler implements ServerLogoutHandler {
private final List<ServerLogoutHandler> delegates;
private final List<ServerLogoutHandler> delegates = new ArrayList<>();
public DelegatingServerLogoutHandler(ServerLogoutHandler... delegates) {
Assert.notEmpty(delegates, "delegates cannot be null or empty");
this.delegates = Arrays.asList(delegates);
this.delegates.addAll(Arrays.asList(delegates));
}
public DelegatingServerLogoutHandler(List<ServerLogoutHandler> delegates) {
public DelegatingServerLogoutHandler(Collection<ServerLogoutHandler> delegates) {
Assert.notEmpty(delegates, "delegates cannot be null or empty");
this.delegates = new ArrayList<>(delegates);
this.delegates.addAll(delegates);
}
@Override
public Mono<Void> logout(WebFilterExchange exchange, Authentication authentication) {
Stream<Mono<Void>> results = this.delegates.stream().map(delegate -> delegate.logout(exchange, authentication));
return Mono.when(results.collect(Collectors.toList()));
return Mono.when(this.delegates.stream()
.filter(Objects::nonNull)
.map(delegate -> delegate.logout(exchange, authentication))
.collect(Collectors.toList())
);
}
}

View File

@ -16,17 +16,17 @@
package org.springframework.security.web.server.authentication.logout;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpMethod;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
@ -85,6 +85,10 @@ public class LogoutWebFilter implements WebFilter {
this.logoutSuccessHandler = logoutSuccessHandler;
}
/**
* Sets the {@link ServerLogoutHandler}. The default is {@link SecurityContextServerLogoutHandler}.
* @param logoutHandler The handler to use
*/
public void setLogoutHandler(ServerLogoutHandler logoutHandler) {
Assert.notNull(logoutHandler, "logoutHandler must not be null");
this.logoutHandler = logoutHandler;

View File

@ -0,0 +1,86 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server.authentication.logout;
import static org.assertj.core.api.Assertions.assertThat;
import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Collectors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.test.util.ReflectionTestUtils;
/**
* @author Eric Deandrea
* @since 5.1
*/
@RunWith(MockitoJUnitRunner.class)
public class LogoutWebFilterTests {
@Mock
private ServerLogoutHandler handler1;
@Mock
private ServerLogoutHandler handler2;
@Mock
private ServerLogoutHandler handler3;
private LogoutWebFilter logoutWebFilter = new LogoutWebFilter();
@Test
public void defaultLogoutHandler() {
assertThat(getLogoutHandler())
.isNotNull()
.isExactlyInstanceOf(SecurityContextServerLogoutHandler.class);
}
@Test
public void singleLogoutHandler() {
this.logoutWebFilter.setLogoutHandler(this.handler1);
this.logoutWebFilter.setLogoutHandler(this.handler2);
assertThat(getLogoutHandler())
.isNotNull()
.isInstanceOf(ServerLogoutHandler.class)
.isNotInstanceOf(SecurityContextServerLogoutHandler.class)
.extracting(ServerLogoutHandler::getClass)
.isEqualTo(this.handler2.getClass());
}
@Test
public void multipleLogoutHandlers() {
this.logoutWebFilter.setLogoutHandler(new DelegatingServerLogoutHandler(this.handler1, this.handler2, this.handler3));
assertThat(getLogoutHandler())
.isNotNull()
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler -> ((Collection<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates"))
.stream()
.map(ServerLogoutHandler::getClass)
.collect(Collectors.toList()))
.isEqualTo(Arrays.asList(this.handler1.getClass(), this.handler2.getClass(), this.handler3.getClass()));
}
private ServerLogoutHandler getLogoutHandler() {
return (ServerLogoutHandler) ReflectionTestUtils.getField(this.logoutWebFilter, LogoutWebFilter.class, "logoutHandler");
}
}