diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java index fcf46e17fd..894e98fbc7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java @@ -47,7 +47,9 @@ public class MessageSecurityMetadataSourceRegistry { private final LinkedHashMap matcherToExpression = new LinkedHashMap(); - private PathMatcher pathMatcher = new AntPathMatcher(); + private DelegatingPathMatcher pathMatcher = new DelegatingPathMatcher(); + + private boolean defaultPathMatcher = true; /** * Maps any {@link Message} to a security expression. @@ -169,10 +171,20 @@ public class MessageSecurityMetadataSourceRegistry { public MessageSecurityMetadataSourceRegistry simpDestPathMatcher( PathMatcher pathMatcher) { Assert.notNull(pathMatcher, "pathMatcher cannot be null"); - this.pathMatcher = pathMatcher; + this.pathMatcher.setPathMatcher(pathMatcher); + this.defaultPathMatcher = false; return this; } + /** + * Determines if the {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set. + * + * @return true if {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set, else false. + */ + protected boolean isSimpDestPathMatcherConfigured() { + return !this.defaultPathMatcher; + } + /** * Maps a {@link List} of {@link MessageMatcher} instances to a security expression. * @@ -439,4 +451,42 @@ public class MessageSecurityMetadataSourceRegistry { private interface MatcherBuilder { MessageMatcher build(); } + + + static class DelegatingPathMatcher implements PathMatcher { + + private PathMatcher delegate = new AntPathMatcher(); + + public boolean isPattern(String path) { + return delegate.isPattern(path); + } + + public boolean match(String pattern, String path) { + return delegate.match(pattern, path); + } + + public boolean matchStart(String pattern, String path) { + return delegate.matchStart(pattern, path); + } + + public String extractPathWithinPattern(String pattern, String path) { + return delegate.extractPathWithinPattern(pattern, path); + } + + public Map extractUriTemplateVariables(String pattern, String path) { + return delegate.extractUriTemplateVariables(pattern, path); + } + + public Comparator getPatternComparator(String path) { + return delegate.getPatternComparator(path); + } + + public String combine(String pattern1, String pattern2) { + return delegate.combine(pattern1, pattern2); + } + + void setPathMatcher(PathMatcher pathMatcher) { + this.delegate = pathMatcher; + } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java index 40e2dfe2c4..1646ca4a12 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java @@ -15,6 +15,11 @@ */ package org.springframework.security.config.annotation.web.socket; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; @@ -22,6 +27,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.vote.AffirmativeBased; @@ -33,6 +39,8 @@ import org.springframework.security.messaging.context.AuthenticationPrincipalArg import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; +import org.springframework.util.AntPathMatcher; +import org.springframework.util.PathMatcher; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; @@ -42,10 +50,6 @@ import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - /** * Allows configuring WebSocket Authorization. * @@ -57,7 +61,7 @@ import java.util.Map; * @Configuration * public class WebSocketSecurityConfig extends * AbstractSecurityWebSocketMessageBrokerConfigurer { - * + * * @Override * protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { * messages.simpDestMatchers("/user/queue/errors").permitAll() @@ -99,6 +103,14 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends customizeClientInboundChannel(registration); } + private PathMatcher getDefaultPathMatcher() { + try { + return context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher(); + } catch(NoSuchBeanDefinitionException e) { + return new AntPathMatcher(); + } + } + /** *

* Determines if a CSRF token is required for connecting. This protects against remote @@ -169,6 +181,11 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends protected boolean containsMapping() { return super.containsMapping(); } + + @Override + protected boolean isSimpDestPathMatcherConfigured() { + return super.isSimpDestPathMatcherConfigured(); + } } @Autowired @@ -225,5 +242,10 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends + object); } } + + if (inboundRegistry.containsMapping() && !inboundRegistry.isSimpDestPathMatcherConfigured()) { + PathMatcher pathMatcher = getDefaultPathMatcher(); + inboundRegistry.simpDestPathMatcher(pathMatcher); + } } } \ No newline at end of file diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index f13edd902e..3c366131f5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -1,5 +1,4 @@ /* - * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of @@ -17,7 +16,6 @@ package org.springframework.security.config.annotation.web.socket; import org.junit.After; import org.junit.Before; - import org.junit.Test; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -46,6 +44,7 @@ import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.util.AntPathMatcher; import org.springframework.web.HttpRequestHandler; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.servlet.HandlerMapping; @@ -59,6 +58,7 @@ import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHa import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; import javax.servlet.http.HttpServletRequest; + import java.util.HashMap; import java.util.Map; @@ -232,6 +232,163 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { assertHandshake(request); } + @Test + public void msmsRegistryCustomPatternMatcher() + throws Exception { + loadConfig(MsmsRegistryCustomPatternMatcherConfig.class); + + clientInboundChannel().send(message("/app/a.b")); + + try { + clientInboundChannel().send(message("/app/a.b.c")); + fail("Expected Exception"); + } + catch (MessageDeliveryException expected) { + assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class); + } + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class MsmsRegistryCustomPatternMatcherConfig extends + AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on + + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/app/a.*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setPathMatcher(new AntPathMatcher(".")); + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + + @Bean + public TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + } + + @Test + public void overrideMsmsRegistryCustomPatternMatcher() + throws Exception { + loadConfig(OverrideMsmsRegistryCustomPatternMatcherConfig.class); + + clientInboundChannel().send(message("/app/a/b")); + + try { + clientInboundChannel().send(message("/app/a/b/c")); + fail("Expected Exception"); + } + catch (MessageDeliveryException expected) { + assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class); + } + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class OverrideMsmsRegistryCustomPatternMatcherConfig extends + AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on + + + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestPathMatcher(new AntPathMatcher()) + .simpDestMatchers("/app/a/*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setPathMatcher(new AntPathMatcher(".")); + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + + @Bean + public TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + } + + @Test + public void defaultPatternMatcher() + throws Exception { + loadConfig(DefaultPatternMatcherConfig.class); + + clientInboundChannel().send(message("/app/a/b")); + + try { + clientInboundChannel().send(message("/app/a/b/c")); + fail("Expected Exception"); + } + catch (MessageDeliveryException expected) { + assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class); + } + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class DefaultPatternMatcherConfig extends + AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on + + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/app/a/*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + + @Bean + public TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + } + private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = context .getBean(TestHandshakeHandler.class); @@ -358,10 +515,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); } + // @formatter:off @Override protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { - messages.simpDestMatchers("/permitAll/**").permitAll().anyMessage().denyAll(); + messages + .simpDestMatchers("/permitAll/**").permitAll() + .anyMessage().denyAll(); } + // @formatter:on @Override public void configureMessageBroker(MessageBrokerRegistry registry) { @@ -431,10 +592,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { .addInterceptors(new HttpSessionHandshakeInterceptor()); } + // @formatter:off @Override protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { - messages.simpDestMatchers("/permitAll/**").permitAll().anyMessage().denyAll(); + messages + .simpDestMatchers("/permitAll/**").permitAll() + .anyMessage().denyAll(); } + // @formatter:on @Bean public TestHandshakeHandler testHandshakeHandler() {