diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java index 18f170a548..6fd1d50140 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java @@ -24,6 +24,7 @@ import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.messaging.Message; @@ -33,7 +34,10 @@ import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.access.vote.AffirmativeBased; +import org.springframework.security.config.annotation.ObjectPostProcessor; +import org.springframework.security.config.annotation.configuration.ObjectPostProcessorConfiguration; import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; +import org.springframework.security.messaging.access.expression.DefaultMessageSecurityExpressionHandler; import org.springframework.security.messaging.access.expression.MessageExpressionVoter; import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor; import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; @@ -78,10 +82,13 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe * @author Rob Winch */ @Order(Ordered.HIGHEST_PRECEDENCE + 100) +@Import(ObjectPostProcessorConfiguration.class) public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer implements SmartInitializingSingleton { private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); + private SecurityExpressionHandler> defaultExpressionHandler = new DefaultMessageSecurityExpressionHandler(); + private SecurityExpressionHandler> expressionHandler; private ApplicationContext context; @@ -150,9 +157,7 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor( inboundMessageSecurityMetadataSource()); MessageExpressionVoter voter = new MessageExpressionVoter(); - if(expressionHandler != null) { - voter.setExpressionHandler(expressionHandler); - } + voter.setExpressionHandler(getMessageExpressionHandler()); List> voters = new ArrayList>(); voters.add(voter); @@ -169,9 +174,7 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends @Bean public MessageSecurityMetadataSource inboundMessageSecurityMetadataSource() { - if(expressionHandler != null) { - inboundRegistry.expressionHandler(expressionHandler); - } + inboundRegistry.expressionHandler(getMessageExpressionHandler()); configureInbound(inboundRegistry); return inboundRegistry.createMetadataSource(); } @@ -218,6 +221,18 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends } } + @Autowired(required = false) + public void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { + defaultExpressionHandler = objectPostProcessor.postProcess(defaultExpressionHandler); + } + + private SecurityExpressionHandler> getMessageExpressionHandler() { + if(expressionHandler == null) { + return defaultExpressionHandler; + } + return expressionHandler; + } + public void afterSingletonsInstantiated() { if (sameOriginDisabled()) { return; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 9e012b2dfd..759fd7dbb1 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -117,6 +117,15 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { clientInboundChannel().send(message("/permitAll")); } + // gh-3797 + @Test + public void beanResolver() { + loadConfig(SockJsSecurityConfig.class); + + messageUser = null; + clientInboundChannel().send(message("/beanResolver")); + } + @Test public void addsAuthenticationPrincipalResolver() throws InterruptedException { loadConfig(SockJsSecurityConfig.class); @@ -594,6 +603,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { messages .simpDestMatchers("/permitAll/**").permitAll() + .simpDestMatchers("/beanResolver/**").access("@security.check()") .anyMessage().denyAll(); } // @formatter:on @@ -613,6 +623,20 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { public TestHandshakeHandler testHandshakeHandler() { return new TestHandshakeHandler(); } + + @Bean + public SecurityCheck security() { + return new SecurityCheck(); + } + + static class SecurityCheck { + private boolean check; + + public boolean check() { + check = !check; + return check; + } + } } @Configuration