Simplify Csrf Processor Decision Logic

Replaces repeated if-else string comparisons with a Set.contains() check
for known WebSocket handshake handler class names in MessageSecurityPostProcessor.

Improves readability and maintainability without changing behavior.

Signed-off-by: Wonpyo Hong <evga7@naver.com>
This commit is contained in:
evga7 2025-06-15 05:04:38 +09:00 committed by Josh Cummings
parent 676b44ebb0
commit 06ed6ef342

View File

@ -19,6 +19,11 @@ package org.springframework.security.config.websocket;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.w3c.dom.Element; import org.w3c.dom.Element;
@ -307,6 +312,13 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
private static final String TEMPLATE_EXPRESSION_BEAN_ID = "annotationExpressionTemplateDefaults"; private static final String TEMPLATE_EXPRESSION_BEAN_ID = "annotationExpressionTemplateDefaults";
private static final Set<String> CSRF_HANDSHAKE_HANDLER_CLASSES = Collections.unmodifiableSet(
new HashSet<>(Arrays.asList(
"org.springframework.web.socket.server.support.WebSocketHttpRequestHandler",
"org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService",
"org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService"
)));
private final String inboundSecurityInterceptorId; private final String inboundSecurityInterceptorId;
private final boolean sameOriginDisabled; private final boolean sameOriginDisabled;
@ -345,16 +357,7 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
} }
} }
} }
else if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler" else if (CSRF_HANDSHAKE_HANDLER_CLASSES.contains(beanClassName)) {
.equals(beanClassName)) {
addCsrfTokenHandshakeInterceptor(bd);
}
else if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService"
.equals(beanClassName)) {
addCsrfTokenHandshakeInterceptor(bd);
}
else if ("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService"
.equals(beanClassName)) {
addCsrfTokenHandshakeInterceptor(bd); addCsrfTokenHandshakeInterceptor(bd);
} }
} }