From a55021539a3b7f8e9aebe6781a39ce4033050215 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:32:44 -0700 Subject: [PATCH] Add RSocket and WebFlux Observation Tests Issue gh-11989 Issue gh-11990 --- .../HelloRSocketObservationITests.java | 200 ++++++++++++++++++ .../ServerHttpSecurityConfigurationTests.java | 81 +++++++ 2 files changed, 281 insertions(+) create mode 100644 config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketObservationITests.java diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketObservationITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketObservationITests.java new file mode 100644 index 0000000000..9b6cf17b8c --- /dev/null +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketObservationITests.java @@ -0,0 +1,200 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.rsocket.core.SecuritySocketAcceptorInterceptor; +import org.springframework.security.rsocket.metadata.SimpleAuthenticationEncoder; +import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +import org.springframework.util.MimeTypeUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * @author Rob Winch + */ +@ContextConfiguration +@ExtendWith(SpringExtension.class) +public class HelloRSocketObservationITests { + + @Autowired + RSocketMessageHandler handler; + + @Autowired + SecuritySocketAcceptorInterceptor interceptor; + + @Autowired + ServerController controller; + + @Autowired + ObservationHandler observationHandler; + + private CloseableChannel server; + + private RSocketRequester requester; + + @BeforeEach + public void setup() { + // @formatter:off + this.server = RSocketServer.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) + .acceptor(this.handler.responder()) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + // @formatter:on + } + + @AfterEach + public void dispose() { + this.requester.rsocket().dispose(); + this.server.dispose(); + this.controller.payloads.clear(); + } + + @Test + public void getWhenUsingObservationRegistryThenObservesRequest() { + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); + // @formatter:off + this.requester = RSocketRequester.builder() + .setupMetadata(credentials, MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString())) + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on + String data = "rob"; + // @formatter:off + this.requester.route("secure.retrieve-mono") + .metadata(credentials, MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString())) + .data(data) + .retrieveMono(String.class) + .block(); + // @formatter:on + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(this.observationHandler, times(2)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + // once for setup + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + // once for request + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + } + + @Configuration + @EnableRSocketSecurity + static class Config { + + private ObservationHandler handler = mock(ObservationHandler.class); + + @Bean + ServerController controller() { + return new ServerController(); + } + + @Bean + RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; + } + + @Bean + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new SimpleAuthenticationEncoder()).build(); + } + + @Bean + MapReactiveUserDetailsService uds() { + // @formatter:off + UserDetails rob = User.withDefaultPasswordEncoder() + .username("rob") + .password("password") + .roles("USER", "ADMIN") + .build(); + // @formatter:on + return new MapReactiveUserDetailsService(rob); + } + + @Bean + ObservationHandler observationHandler() { + return this.handler; + } + + @Bean + ObservationRegistry observationRegistry() { + given(this.handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(this.handler); + return registry; + } + + } + + @Controller + static class ServerController { + + private List payloads = new ArrayList<>(); + + @MessageMapping("**") + String retrieveMono(String payload) { + add(payload); + return "Hi " + payload; + } + + private void add(String p) { + this.payloads.add(p); + } + + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java index 006b3f190a..82a5d9020b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java @@ -21,20 +21,27 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.net.URI; +import java.util.Iterator; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import reactor.core.publisher.Mono; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.password.CompromisedPasswordDecision; import org.springframework.security.authentication.password.CompromisedPasswordException; import org.springframework.security.authentication.password.ReactiveCompromisedPasswordChecker; import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.rsocket.EnableRSocketSecurity; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration; @@ -60,6 +67,12 @@ import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockAuthentication; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; @@ -205,6 +218,30 @@ public class ServerHttpSecurityConfigurationTests { .isEqualTo("harold"); } + @Test + public void getWhenUsingObservationRegistryThenObservesRequest() { + this.spring.register(ObservationRegistryConfig.class).autowire(); + // @formatter:off + this.webClient + .get() + .uri("/hello") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus() + .isNotFound(); + // @formatter:on + ObservationHandler handler = this.spring.getContext().getBean(ObservationHandler.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(6)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getContextualName()).isEqualTo("http get"); + assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain before"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests"); + assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after"); + } + @Configuration static class SubclassConfig extends ServerHttpSecurityConfiguration { @@ -368,4 +405,48 @@ public class ServerHttpSecurityConfigurationTests { } + @Configuration + @EnableWebFlux + @EnableWebFluxSecurity + static class ObservationRegistryConfig { + + private ObservationHandler handler = mock(ObservationHandler.class); + + @Bean + SecurityWebFilterChain app(ServerHttpSecurity http) throws Exception { + http.httpBasic(withDefaults()).authorizeExchange((authorize) -> authorize.anyExchange().authenticated()); + return http.build(); + } + + @Bean + ReactiveUserDetailsService userDetailsService() { + return new MapReactiveUserDetailsService( + User.withDefaultPasswordEncoder().username("user").password("password").authorities("app").build()); + } + + @Bean + ObservationHandler observationHandler() { + return this.handler; + } + + @Bean + ObservationRegistry observationRegistry() { + given(this.handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(this.handler); + return registry; + } + + } + + @EnableRSocketSecurity + static class RSocketSecurityConfig { + + @Bean + RSocketMessageHandler messageHandler() { + return new RSocketMessageHandler(); + } + + } + }