diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java index 8c5db34fb5..8dfc4da381 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java @@ -32,6 +32,8 @@ import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.authorization.SpringAuthorizationEventPublisher; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor; import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; @@ -59,7 +61,10 @@ final class WebSocketMessageBrokerSecurityConfiguration private static final AuthorizationManager> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager .builder().anyMessage().authenticated().build(); - private final ChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); @@ -74,17 +79,27 @@ final class WebSocketMessageBrokerSecurityConfiguration @Override public void addArgumentResolvers(List argumentResolvers) { - argumentResolvers.add(new AuthenticationPrincipalArgumentResolver()); + AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); + resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + argumentResolvers.add(resolver); } @Override public void configureClientInboundChannel(ChannelRegistration registration) { this.authorizationChannelInterceptor .setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context)); + this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + this.securityContextChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); registration.interceptors(this.securityContextChannelInterceptor, this.csrfChannelInterceptor, this.authorizationChannelInterceptor); } + @Autowired(required = false) + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + @Autowired(required = false) void setAuthorizationManager(AuthorizationManager> authorizationManager) { this.authorizationChannelInterceptor = new AuthorizationChannelInterceptor(authorizationManager); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index d2db1059a0..e9fef676c6 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -54,9 +54,11 @@ import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.AuthorizationManager; +import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; import org.springframework.security.core.Authentication; import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor; import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext; import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; @@ -84,6 +86,8 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; public class WebSocketMessageBrokerSecurityConfigurationTests { @@ -225,6 +229,18 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { assertHandshake(request); } + @Test + public void messagesContextWebSocketUseSecurityContextHolderStrategy() { + loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); + Message message = message(headers, "/authenticated"); + headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + MessageChannel messageChannel = clientInboundChannel(); + messageChannel.send(message); + verify(this.context.getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext(); + } + @Test public void msmsRegistryCustomPatternMatcher() { loadConfig(MsmsRegistryCustomPatternMatcherConfig.class); @@ -691,6 +707,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { // @formatter:off messages .simpDestMatchers("/permitAll/**").permitAll() + .simpDestMatchers("/authenticated/**").authenticated() .anyMessage().denyAll(); // @formatter:on return messages.build();