From 4f5b17334ea976af13f848185734aaf05a66dcfe Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:53:27 -0600 Subject: [PATCH] Pick Up csrfChannelInterceptor in XML Closes gh-17493 --- ...ageBrokerSecurityBeanDefinitionParser.java | 9 ++++- .../WebSocketMessageBrokerConfigTests.java | 11 ++++++ ...rokerConfigTests-CustomCsrfInterceptor.xml | 34 +++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index e833c6f3f7..876a3a93c7 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -301,6 +301,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_ID = "csrfChannelInterceptor"; + private static final String INTERCEPTORS_PROP = "interceptors"; private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers"; @@ -356,7 +358,12 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements ManagedList interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); if (!this.sameOriginDisabled) { - interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + if (!registry.containsBeanDefinition(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)) { + interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + } + else { + interceptors.add(new RuntimeBeanReference(CSRF_CHANNEL_INTERCEPTOR_BEAN_ID)); + } } interceptors.add(registry.getBeanDefinition(this.inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index c8bf1d8eb1..4d45105d00 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -44,6 +44,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.GenericMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.expression.SecurityExpressionOperations; @@ -496,6 +497,16 @@ public class WebSocketMessageBrokerConfigTests { verify(authorizationManager).check(any(), any()); } + @Test + public void configureWhenCsrfChannelInterceptorBeanThenUses() { + this.spring.configLocations(xml("CustomCsrfInterceptor")).autowire(); + ExecutorSubscribableChannel channel = this.spring.getContext() + .getBean("clientInboundChannel", ExecutorSubscribableChannel.class); + ChannelInterceptor interceptor = this.spring.getContext() + .getBean("csrfChannelInterceptor", ChannelInterceptor.class); + assertThat(channel.getInterceptors()).contains(interceptor); + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } diff --git a/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml new file mode 100644 index 0000000000..9883f56adf --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-CustomCsrfInterceptor.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + +