diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index edb6646b36..5f0b2dc388 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -301,6 +301,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_ID = "csrfChannelInterceptor"; + private static final String INTERCEPTORS_PROP = "interceptors"; private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers"; @@ -364,7 +366,12 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements ManagedList interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); if (!this.sameOriginDisabled) { - interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + if (!registry.containsBeanDefinition(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)) { + interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + } + else { + interceptors.add(new RuntimeBeanReference(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)); + } } interceptors.add(registry.getBeanDefinition(this.inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index 6e999933a2..1a27caaac9 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -48,6 +48,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.GenericMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.expression.SecurityExpressionOperations; @@ -521,6 +522,16 @@ public class WebSocketMessageBrokerConfigTests { verify(authorizationManager).check(any(), any()); } + @Test + public void configureWhenCsrfChannelInterceptorBeanThenUses() { + this.spring.configLocations(xml("CustomCsrfInterceptor")).autowire(); + ExecutorSubscribableChannel channel = this.spring.getContext() + .getBean("clientInboundChannel", ExecutorSubscribableChannel.class); + ChannelInterceptor interceptor = this.spring.getContext() + .getBean("csrfChannelInterceptor", ChannelInterceptor.class); + assertThat(channel.getInterceptors()).contains(interceptor); + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } diff --git a/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml new file mode 100644 index 0000000000..9883f56adf --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + +