Add CORS WebFlux Support

Fixes: gh-4832
This commit is contained in:
Rob Winch 2018-07-31 11:37:20 -05:00
parent fe17c71775
commit cecbc2175b
4 changed files with 236 additions and 0 deletions

View File

@ -23,6 +23,10 @@ package org.springframework.security.config.web.server;
public enum SecurityWebFiltersOrder { public enum SecurityWebFiltersOrder {
FIRST(Integer.MIN_VALUE), FIRST(Integer.MIN_VALUE),
HTTP_HEADERS_WRITER, HTTP_HEADERS_WRITER,
/**
* {@link org.springframework.web.cors.reactive.CorsWebFilter}
*/
CORS,
/** /**
* {@link org.springframework.security.web.server.csrf.CsrfWebFilter} * {@link org.springframework.security.web.server.csrf.CsrfWebFilter}
*/ */

View File

@ -111,6 +111,10 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.web.cors.reactive.CorsConfigurationSource;
import org.springframework.web.cors.reactive.CorsProcessor;
import org.springframework.web.cors.reactive.CorsWebFilter;
import org.springframework.web.cors.reactive.DefaultCorsProcessor;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
@ -181,6 +185,8 @@ public class ServerHttpSecurity {
private CsrfSpec csrf = new CsrfSpec(); private CsrfSpec csrf = new CsrfSpec();
private CorsSpec cors = new CorsSpec();
private ExceptionHandlingSpec exceptionHandling = new ExceptionHandlingSpec(); private ExceptionHandlingSpec exceptionHandling = new ExceptionHandlingSpec();
private HttpBasicSpec httpBasic; private HttpBasicSpec httpBasic;
@ -299,6 +305,80 @@ public class ServerHttpSecurity {
return this.csrf; return this.csrf;
} }
/**
* Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is found, it will be used
* to create a {@link CorsWebFilter}. If {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked
* it will be used instead. If neither has been configured, the Cors configuration will do nothing.
* @return the {@link CorsSpec} to customize
*/
public CorsSpec cors() {
if (this.cors == null) {
this.cors = new CorsSpec();
}
return this.cors;
}
/**
* Configures CORS support within Spring Security. This ensures that the {@link CorsWebFilter} is place in the
* correct order.
*/
public class CorsSpec {
private CorsWebFilter corsFilter;
/**
* Configures the {@link CorsConfigurationSource} to be used
* @param source the source to use
* @return the {@link CorsSpec} for additional configuration
*/
public CorsSpec configurationSource(CorsConfigurationSource source) {
this.corsFilter = new CorsWebFilter(source);
return this;
}
/**
* Disables CORS support within Spring Security.
* @return the {@link ServerHttpSecurity} to continue configuring
*/
public ServerHttpSecurity disable() {
ServerHttpSecurity.this.cors = null;
return ServerHttpSecurity.this;
}
/**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring
*/
public ServerHttpSecurity and() {
return ServerHttpSecurity.this;
}
protected void configure(ServerHttpSecurity http) {
CorsWebFilter corsFilter = getCorsFilter();
if (corsFilter != null) {
http.addFilterAt(this.corsFilter, SecurityWebFiltersOrder.CORS);
}
}
private CorsWebFilter getCorsFilter() {
if (this.corsFilter != null) {
return this.corsFilter;
}
CorsConfigurationSource source = getBeanOrNull(CorsConfigurationSource.class);
if (source == null) {
return null;
}
CorsProcessor processor = getBeanOrNull(CorsProcessor.class);
if (processor == null) {
processor = new DefaultCorsProcessor();
}
this.corsFilter = new CorsWebFilter(source, processor);
return this.corsFilter;
}
private CorsSpec() {}
}
/** /**
* Configures HTTP Basic authentication. An example configuration is provided below: * Configures HTTP Basic authentication. An example configuration is provided below:
* *
@ -782,6 +862,9 @@ public class ServerHttpSecurity {
if(this.csrf != null) { if(this.csrf != null) {
this.csrf.configure(this); this.csrf.configure(this);
} }
if (this.cors != null) {
this.cors.configure(this);
}
if(this.httpBasic != null) { if(this.httpBasic != null) {
this.httpBasic.authenticationManager(this.authenticationManager); this.httpBasic.authenticationManager(this.authenticationManager);
this.httpBasic.configure(this); this.httpBasic.configure(this);

View File

@ -0,0 +1,117 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.web.server;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.http.HttpHeaders;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.CorsConfigurationSource;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
public class CorsSpecTests {
@Mock
private CorsConfigurationSource source;
@Mock
private ApplicationContext context;
ServerHttpSecurity http;
HttpHeaders expectedHeaders = new HttpHeaders();
Set<String> headerNamesNotPresent = new HashSet<>();
@Before
public void setup() {
this.http = new TestingServerHttpSecurity()
.applicationContext(this.context);
CorsConfiguration value = new CorsConfiguration();
value.setAllowedOrigins(Arrays.asList("*"));
when(this.source.getCorsConfiguration(any())).thenReturn(value);
}
@Test
public void corsWhenEnabledThenAccessControlAllowOriginAndSecurityHeaders() {
this.http.cors().configurationSource(this.source);
this.expectedHeaders.set("Access-Control-Allow-Origin", "*");
this.expectedHeaders.set("X-Frame-Options", "DENY");
assertHeaders();
}
@Test
public void corsWhenCorsConfigurationSourceBeanThenAccessControlAllowOriginAndSecurityHeaders() {
when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[] {"source"}, new String[0]);
when(this.context.getBean("source")).thenReturn(this.source);
this.expectedHeaders.set("Access-Control-Allow-Origin", "*");
this.expectedHeaders.set("X-Frame-Options", "DENY");
assertHeaders();
}
@Test
public void corsWhenNoConfigurationSourceThenNoCorsHeaders() {
when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[0]);
this.headerNamesNotPresent.add("Access-Control-Allow-Origin");
assertHeaders();
}
private void assertHeaders() {
WebTestClient client = buildClient();
FluxExchangeResult<String> response = client.get()
.uri("https://example.com/")
.headers(h -> h.setOrigin("https://origin.example.com"))
.exchange()
.returnResult(String.class);
Map<String, List<String>> responseHeaders = response.getResponseHeaders();
if (!this.expectedHeaders.isEmpty()) {
assertThat(responseHeaders).describedAs(response.toString())
.containsAllEntriesOf(this.expectedHeaders);
}
if (!this.headerNamesNotPresent.isEmpty()) {
assertThat(responseHeaders.keySet()).doesNotContainAnyElementsOf(this.headerNamesNotPresent);
}
}
private WebTestClient buildClient() {
return WebTestClientBuilder
.bindToWebFilters(this.http.build())
.build();
}
}

View File

@ -0,0 +1,32 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.web.server;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
/**
* @author Rob Winch
* @since 5.1
*/
public class TestingServerHttpSecurity extends ServerHttpSecurity {
public TestingServerHttpSecurity applicationContext(ApplicationContext applicationContext)
throws BeansException {
super.setApplicationContext(applicationContext);
return this;
}
}