diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index d5eafacf97..edb6646b36 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -305,7 +305,7 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers"; - private static final String TEMPLATE_EXPRESSION_BEAN_ID = "templateDefaults"; + private static final String TEMPLATE_EXPRESSION_BEAN_ID = "annotationExpressionTemplateDefaults"; private final String inboundSecurityInterceptorId; @@ -333,7 +333,7 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements AuthenticationPrincipalArgumentResolver.class); if (registry.containsBeanDefinition(TEMPLATE_EXPRESSION_BEAN_ID)) { beanDefinition.getPropertyValues() - .add(TEMPLATE_EXPRESSION_BEAN_ID, new RuntimeBeanReference(TEMPLATE_EXPRESSION_BEAN_ID)); + .add("templateDefaults", new RuntimeBeanReference(TEMPLATE_EXPRESSION_BEAN_ID)); } argResolvers.add(beanDefinition); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java index 322dd22dea..2578aa3b18 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java @@ -16,6 +16,11 @@ package org.springframework.security.config.annotation.web.configuration; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -26,6 +31,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; @@ -39,12 +45,15 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.view; @@ -97,10 +106,28 @@ public class WebMvcSecurityConfigurationTests { this.mockMvc.perform(request).andExpect(assertResult(csrfToken)); } + @Test + public void metaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throws Exception { + this.mockMvc.perform(get("/hi")).andExpect(content().string("Hi, Stranger!")); + Authentication harold = new TestingAuthenticationToken("harold", "password", + AuthorityUtils.createAuthorityList("ROLE_USER")); + SecurityContextHolder.getContext().setAuthentication(harold); + this.mockMvc.perform(get("/hi")).andExpect(content().string("Hi, Harold!")); + } + private ResultMatcher assertResult(Object expected) { return model().attribute("result", expected); } + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.PARAMETER) + @AuthenticationPrincipal(expression = "#this.equals('{value}')") + @interface IsUser { + + String value() default "user"; + + } + @Controller static class TestController { @@ -120,6 +147,17 @@ public class WebMvcSecurityConfigurationTests { return new ModelAndView("view", "result", token); } + @GetMapping("/hi") + @ResponseBody + String ifUser(@IsUser("harold") boolean isHarold) { + if (isHarold) { + return "Hi, Harold!"; + } + else { + return "Hi, Stranger!"; + } + } + } @Configuration @@ -132,6 +170,11 @@ public class WebMvcSecurityConfigurationTests { return new TestController(); } + @Bean + AnnotationTemplateExpressionDefaults templateExpressionDefaults() { + return new AnnotationTemplateExpressionDefaults(); + } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java index 972a6eb539..9743782942 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java @@ -16,6 +16,10 @@ package org.springframework.security.config.annotation.web.reactive; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.net.URI; import org.junit.jupiter.api.Test; @@ -26,6 +30,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.password.CompromisedPasswordDecision; import org.springframework.security.authentication.password.CompromisedPasswordException; import org.springframework.security.authentication.password.ReactiveCompromisedPasswordChecker; @@ -34,8 +39,12 @@ import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration; import org.springframework.security.config.web.server.ServerHttpSecurity; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults; +import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; import org.springframework.security.core.userdetails.PasswordEncodedUser; +import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.web.server.DefaultServerRedirectStrategy; @@ -43,12 +52,16 @@ import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockAuthentication; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; /** * Tests for {@link ServerHttpSecurityConfiguration}. @@ -67,7 +80,10 @@ public class ServerHttpSecurityConfigurationTests { if (!context.containsBean(WebHttpHandlerBuilder.WEB_HANDLER_BEAN_NAME)) { return; } - this.webClient = WebTestClient.bindToApplicationContext(context).configureClient().build(); + this.webClient = WebTestClient.bindToApplicationContext(context) + .apply(springSecurity()) + .configureClient() + .build(); } @Test @@ -146,6 +162,27 @@ public class ServerHttpSecurityConfigurationTests { // @formatter:on } + @Test + public void metaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throws Exception { + this.spring.register(MetaAnnotationPlaceholderConfig.class).autowire(); + Authentication user = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + this.webClient.mutateWith(mockAuthentication(user)) + .get() + .uri("/hi") + .exchange() + .expectStatus() + .isOk() + .expectBody(String.class) + .isEqualTo("Hi, Stranger!"); + Authentication harold = new TestingAuthenticationToken("harold", "password", "ROLE_USER"); + this.webClient.mutateWith(mockAuthentication(harold)) + .get() + .uri("/hi") + .exchange() + .expectBody(String.class) + .isEqualTo("Hi, Harold!"); + } + @Configuration static class SubclassConfig extends ServerHttpSecurityConfiguration { @@ -237,4 +274,61 @@ public class ServerHttpSecurityConfigurationTests { } + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.PARAMETER) + @AuthenticationPrincipal(expression = "#this.equals('{value}')") + @interface IsUser { + + String value() default "user"; + + } + + @RestController + static class TestController { + + @GetMapping("/hi") + String ifUser(@IsUser("harold") boolean isHarold) { + if (isHarold) { + return "Hi, Harold!"; + } + else { + return "Hi, Stranger!"; + } + } + + } + + @Configuration + @EnableWebFlux + @EnableWebFluxSecurity + static class MetaAnnotationPlaceholderConfig { + + @Bean + SecurityWebFilterChain filterChain(ServerHttpSecurity http) { + // @formatter:off + http + .authorizeExchange((authorize) -> authorize.anyExchange().authenticated()) + .httpBasic(Customizer.withDefaults()); + // @formatter:on + return http.build(); + } + + @Bean + ReactiveUserDetailsService userDetailsService() { + return new MapReactiveUserDetailsService( + User.withUsername("user").password("password").authorities("app").build()); + } + + @Bean + TestController testController() { + return new TestController(); + } + + @Bean + AnnotationTemplateExpressionDefaults templateExpressionDefaults() { + return new AnnotationTemplateExpressionDefaults(); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index a14986b1ee..2210288b57 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -16,6 +16,10 @@ package org.springframework.security.config.annotation.web.socket; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -59,6 +63,7 @@ import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig; import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolderStrategy; @@ -164,6 +169,17 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { .isEqualTo((String) this.messageUser.getPrincipal()); } + @Test + public void sendMessageWhenMetaAnnotationThenParsesExpression() { + loadConfig(NoInboundSecurityConfig.class); + this.messageUser = new TestingAuthenticationToken("harold", "password", "ROLE_USER"); + clientInboundChannel().send(message("/permitAll/hi")); + assertThat(this.context.getBean(MyController.class).message).isEqualTo("Hi, Harold!"); + this.messageUser = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + clientInboundChannel().send(message("/permitAll/hi")); + assertThat(this.context.getBean(MyController.class).message).isEqualTo("Hi, Stranger!"); + } + @Test public void addsCsrfProtectionWhenNoAuthorization() { loadConfig(NoInboundSecurityConfig.class); @@ -365,15 +381,6 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { clientInboundChannel().send(message("/anonymous")); } - @Test - public void sendMessageWhenAnonymousConfiguredAndLoggedInUserThenAccessDeniedException() { - loadConfig(WebSocketSecurityConfig.class); - assertThatExceptionOfType(MessageDeliveryException.class) - .isThrownBy(() -> clientInboundChannel().send(message("/anonymous"))) - .withCauseInstanceOf(AccessDeniedException.class); - - } - private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); @@ -585,6 +592,15 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { } + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.PARAMETER) + @AuthenticationPrincipal(expression = "#this.equals('{value}')") + @interface IsUser { + + String value() default "user"; + + } + @Controller static class MyController { @@ -592,6 +608,8 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { MyCustomArgument myCustomArgument; + String message; + @MessageMapping("/authentication") void authentication(@AuthenticationPrincipal String un) { this.authenticationPrincipal = un; @@ -602,6 +620,11 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { this.myCustomArgument = myCustomArgument; } + @MessageMapping("/hi") + void sayHello(@IsUser("harold") boolean isHarold) { + this.message = isHarold ? "Hi, Harold!" : "Hi, Stranger!"; + } + } static class MyCustomArgument { @@ -735,6 +758,11 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { return new MyController(); } + @Bean + AnnotationTemplateExpressionDefaults templateExpressionDefaults() { + return new AnnotationTemplateExpressionDefaults(); + } + } @Configuration diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index c8bf1d8eb1..b0d37eff2a 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -16,6 +16,10 @@ package org.springframework.security.config.websocket; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.util.HashMap; import java.util.Map; import java.util.function.Supplier; @@ -47,6 +51,7 @@ import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.GenericMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.expression.SecurityExpressionOperations; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.test.SpringTestContext; @@ -55,6 +60,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.messaging.access.expression.DefaultMessageSecurityExpressionHandler; import org.springframework.security.messaging.access.expression.MessageSecurityExpressionRoot; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; @@ -376,6 +382,24 @@ public class WebSocketMessageBrokerConfigTests { assertThat(this.messageController.username).isEqualTo("anonymous"); } + @Test + public void sendMessageWhenMetaAnnotationThenAuthenticationPrincipalResolved() { + this.spring.configLocations(xml("SyncConfig")).autowire(); + Authentication harold = new TestingAuthenticationToken("harold", "password", "ROLE_USER"); + try { + getSecurityContextHolderStrategy().setContext(new SecurityContextImpl(harold)); + this.clientInboundChannel.send(message("/hi")); + assertThat(this.spring.getContext().getBean(MessageController.class).message).isEqualTo("Hi, Harold!"); + Authentication user = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + getSecurityContextHolderStrategy().setContext(new SecurityContextImpl(user)); + this.clientInboundChannel.send(message("/hi")); + assertThat(this.spring.getContext().getBean(MessageController.class).message).isEqualTo("Hi, Stranger!"); + } + finally { + getSecurityContextHolderStrategy().clearContext(); + } + } + @Test public void requestWhenConnectMessageThenUsesCsrfTokenHandshakeInterceptor() throws Exception { this.spring.configLocations(xml("SyncConfig")).autowire(); @@ -553,16 +577,32 @@ public class WebSocketMessageBrokerConfigTests { } + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.PARAMETER) + @AuthenticationPrincipal(expression = "#this.equals('{value}')") + @interface IsUser { + + String value() default "user"; + + } + @Controller static class MessageController { String username; + String message; + @MessageMapping("/message") void authentication(@AuthenticationPrincipal String username) { this.username = username; } + @MessageMapping("/hi") + void sayHello(@IsUser("harold") boolean isHarold) { + this.message = isHarold ? "Hi, Harold!" : "Hi, Stranger!"; + } + } @Controller diff --git a/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-SyncConfig.xml b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-SyncConfig.xml index 54dfdc79ef..667bdd73cb 100644 --- a/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-SyncConfig.xml +++ b/config/src/test/resources/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests-SyncConfig.xml @@ -28,4 +28,5 @@ +