diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index 39b06dc8b1..52848daacf 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -306,6 +306,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_ID = "csrfChannelInterceptor"; + private static final String INTERCEPTORS_PROP = "interceptors"; private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers"; @@ -365,7 +367,12 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements ManagedList interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); if (!this.sameOriginDisabled) { - interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + if (!registry.containsBeanDefinition(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)) { + interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + } + else { + interceptors.add(new RuntimeBeanReference(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)); + } } interceptors.add(registry.getBeanDefinition(this.inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index d95d70fb42..d7128e7ecb 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -48,6 +48,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.GenericMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.expression.SecurityExpressionOperations; @@ -520,6 +521,16 @@ public class WebSocketMessageBrokerConfigTests { verify(authorizationManager).authorize(any(), any()); } + @Test + public void configureWhenCsrfChannelInterceptorBeanThenUses() { + this.spring.configLocations(xml("CustomCsrfInterceptor")).autowire(); + ExecutorSubscribableChannel channel = this.spring.getContext() + .getBean("clientInboundChannel", ExecutorSubscribableChannel.class); + ChannelInterceptor interceptor = this.spring.getContext() + .getBean("csrfChannelInterceptor", ChannelInterceptor.class); + assertThat(channel.getInterceptors()).contains(interceptor); + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } diff --git a/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml new file mode 100644 index 0000000000..9883f56adf --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + +