diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketWithWebFluxITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketWithWebFluxITests.java new file mode 100644 index 0000000000..d2437c76ce --- /dev/null +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketWithWebFluxITests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import java.util.ArrayList; +import java.util.List; + +import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; +import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.rsocket.core.SecuritySocketAcceptorInterceptor; +import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * @author Rob Winch + */ +@ContextConfiguration +@ExtendWith(SpringExtension.class) +public class HelloRSocketWithWebFluxITests { + + @Autowired + RSocketMessageHandler handler; + + @Autowired + SecuritySocketAcceptorInterceptor interceptor; + + @Autowired + ServerController controller; + + private CloseableChannel server; + + private RSocketRequester requester; + + @BeforeEach + public void setup() { + // @formatter:off + this.server = RSocketServer.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) + .acceptor(this.handler.responder()) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + // @formatter:on + } + + @AfterEach + public void dispose() { + this.requester.rsocket().dispose(); + this.server.dispose(); + this.controller.payloads.clear(); + } + + // gh-16161 + @Test + public void retrieveMonoWhenSecureThenDenied() { + // @formatter:off + this.requester = RSocketRequester.builder() + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on + String data = "rob"; + // @formatter:off + assertThatExceptionOfType(Exception.class).isThrownBy( + () -> this.requester.route("secure.retrieve-mono") + .data(data) + .retrieveMono(String.class) + .block() + ) + .matches((ex) -> ex instanceof RejectedSetupException + || ex.getClass().toString().contains("ReactiveException")); + // @formatter:on + assertThat(this.controller.payloads).isEmpty(); + } + + @Configuration + @EnableRSocketSecurity + @EnableWebFluxSecurity + static class Config { + + @Bean + ServerController controller() { + return new ServerController(); + } + + @Bean + RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; + } + + @Bean + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build(); + } + + @Bean + MapReactiveUserDetailsService uds() { + // @formatter:off + UserDetails rob = User.withDefaultPasswordEncoder() + .username("rob") + .password("password") + .roles("USER", "ADMIN") + .build(); + // @formatter:on + return new MapReactiveUserDetailsService(rob); + } + + } + + @Controller + static class ServerController { + + private List payloads = new ArrayList<>(); + + @MessageMapping("**") + String retrieveMono(String payload) { + add(payload); + return "Hi " + payload; + } + + private void add(String p) { + this.payloads.add(p); + } + + } + +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java index da5a5fcb1d..b76ae5ac0a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.rsocket; +import java.util.Map; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -62,8 +64,12 @@ class RSocketSecurityConfiguration { } @Autowired(required = false) - void setAuthenticationManagerPostProcessor(ObjectPostProcessor postProcessor) { - this.postProcessor = postProcessor; + void setAuthenticationManagerPostProcessor( + Map> postProcessors) { + if (postProcessors.size() == 1) { + this.postProcessor = postProcessors.values().iterator().next(); + } + this.postProcessor = postProcessors.get("rSocketAuthenticationManagerPostProcessor"); } @Bean(name = RSOCKET_SECURITY_BEAN_NAME) diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java index 14862d79b9..b8d3880474 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java @@ -29,9 +29,7 @@ import org.springframework.security.authorization.ObservationReactiveAuthorizati import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.config.ObjectPostProcessor; import org.springframework.security.config.observation.SecurityObservationSettings; -import org.springframework.security.web.server.ObservationWebFilterChainDecorator; -import org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator; -import org.springframework.web.server.ServerWebExchange; +import org.springframework.security.rsocket.api.PayloadExchange; @Configuration(proxyBeanMethods = false) @Role(BeanDefinition.ROLE_INFRASTRUCTURE) @@ -45,7 +43,7 @@ class ReactiveObservationConfiguration { @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - static ObjectPostProcessor> rSocketAuthorizationManagerPostProcessor( + static ObjectPostProcessor> rSocketAuthorizationManagerPostProcessor( ObjectProvider registry, ObjectProvider predicate) { return new ObjectPostProcessor<>() { @Override @@ -71,18 +69,4 @@ class ReactiveObservationConfiguration { }; } - @Bean - @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - static ObjectPostProcessor rSocketFilterChainDecoratorPostProcessor( - ObjectProvider registry, ObjectProvider predicate) { - return new ObjectPostProcessor<>() { - @Override - public WebFilterChainDecorator postProcess(WebFilterChainDecorator object) { - ObservationRegistry r = registry.getIfUnique(() -> ObservationRegistry.NOOP); - boolean active = !r.isNoop() && predicate.getIfUnique(() -> all).shouldObserveRequests(); - return active ? new ObservationWebFilterChainDecorator(r) : object; - } - }; - } - } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java index 663b0ba9b0..efb921c389 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java @@ -59,7 +59,7 @@ class ReactiveObservationConfiguration { @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - static ObjectPostProcessor authenticationManagerPostProcessor( + static ObjectPostProcessor reactiveAuthenticationManagerPostProcessor( ObjectProvider registry, ObjectProvider predicate) { return new ObjectPostProcessor<>() { @Override diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java index 8bc535908a..90b8bb19df 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.web.reactive; +import java.util.Map; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.ObjectProvider; @@ -96,8 +98,12 @@ class ServerHttpSecurityConfiguration { } @Autowired(required = false) - void setAuthenticationManagerPostProcessor(ObjectPostProcessor postProcessor) { - this.postProcessor = postProcessor; + void setAuthenticationManagerPostProcessor( + Map> postProcessors) { + if (postProcessors.size() == 1) { + this.postProcessor = postProcessors.values().iterator().next(); + } + this.postProcessor = postProcessors.get("reactiveAuthenticationManagerPostProcessor"); } @Autowired(required = false) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java index 82a5d9020b..4129702b34 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java @@ -242,6 +242,31 @@ public class ServerHttpSecurityConfigurationTests { assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after"); } + // gh-16161 + @Test + public void getWhenUsingRSocketThenObservesRequest() { + this.spring.register(ObservationRegistryConfig.class, RSocketSecurityConfig.class).autowire(); + // @formatter:off + this.webClient + .get() + .uri("/hello") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus() + .isNotFound(); + // @formatter:on + ObservationHandler handler = this.spring.getContext().getBean(ObservationHandler.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(6)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getContextualName()).isEqualTo("http get"); + assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain before"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests"); + assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after"); + } + @Configuration static class SubclassConfig extends ServerHttpSecurityConfiguration {