Add ObservationRegistry Tests

Issue gh-11989
Issue gh-11990
This commit is contained in:
Josh Cummings 2024-09-20 15:14:14 -06:00
parent 717529deb4
commit 1ed20aa210
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
7 changed files with 378 additions and 1 deletions

View File

@ -28,6 +28,10 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import jakarta.annotation.security.DenyAll; import jakarta.annotation.security.DenyAll;
import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation; import org.aopalliance.intercept.MethodInvocation;
@ -40,9 +44,12 @@ import org.springframework.aop.Advisor;
import org.springframework.aop.config.AopConfigUtils; import org.springframework.aop.config.AopConfigUtils;
import org.springframework.aop.support.DefaultPointcutAdvisor; import org.springframework.aop.support.DefaultPointcutAdvisor;
import org.springframework.aop.support.JdkRegexpMethodPointcut; import org.springframework.aop.support.JdkRegexpMethodPointcut;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.annotation.AdviceMode; import org.springframework.context.annotation.AdviceMode;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -1018,6 +1025,43 @@ public class PrePostMethodSecurityConfigurationTests {
assertThat(service.getIdPath("uid")).isEqualTo("uid"); assertThat(service.getIdPath("uid")).isEqualTo("uid");
} }
@Test
@WithMockUser
public void prePostMethodWhenObservationRegistryThenObserved() {
this.spring.register(MethodSecurityServiceEnabledConfig.class, ObservationRegistryConfig.class).autowire();
this.methodSecurityService.preAuthorizePermitAll();
ObservationHandler<?> handler = this.spring.getContext().getBean(ObservationHandler.class);
verify(handler).onStart(any());
verify(handler).onStop(any());
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(this.methodSecurityService::preAuthorize);
verify(handler).onError(any());
}
@Test
@WithMockUser
public void securedMethodWhenObservationRegistryThenObserved() {
this.spring.register(MethodSecurityServiceEnabledConfig.class, ObservationRegistryConfig.class).autowire();
this.methodSecurityService.securedUser();
ObservationHandler<?> handler = this.spring.getContext().getBean(ObservationHandler.class);
verify(handler).onStart(any());
verify(handler).onStop(any());
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(this.methodSecurityService::secured);
verify(handler).onError(any());
}
@Test
@WithMockUser
public void jsr250MethodWhenObservationRegistryThenObserved() {
this.spring.register(MethodSecurityServiceEnabledConfig.class, ObservationRegistryConfig.class).autowire();
this.methodSecurityService.jsr250RolesAllowedUser();
ObservationHandler<?> handler = this.spring.getContext().getBean(ObservationHandler.class);
verify(handler).onStart(any());
verify(handler).onStop(any());
assertThatExceptionOfType(AccessDeniedException.class)
.isThrownBy(this.methodSecurityService::jsr250RolesAllowed);
verify(handler).onError(any());
}
private static Consumer<ConfigurableWebApplicationContext> disallowBeanOverriding() { private static Consumer<ConfigurableWebApplicationContext> disallowBeanOverriding() {
return (context) -> ((AnnotationConfigWebApplicationContext) context).setAllowBeanDefinitionOverriding(false); return (context) -> ((AnnotationConfigWebApplicationContext) context).setAllowBeanDefinitionOverriding(false);
} }
@ -1655,4 +1699,47 @@ public class PrePostMethodSecurityConfigurationTests {
} }
@Configuration
static class ObservationRegistryConfig {
private final ObservationRegistry registry = ObservationRegistry.create();
private final ObservationHandler<Observation.Context> handler = spy(new ObservationTextPublisher());
@Bean
ObservationRegistry observationRegistry() {
return this.registry;
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
ObservationRegistryPostProcessor observationRegistryPostProcessor(
ObjectProvider<ObservationHandler<Observation.Context>> handler) {
return new ObservationRegistryPostProcessor(handler);
}
}
static class ObservationRegistryPostProcessor implements BeanPostProcessor {
private final ObjectProvider<ObservationHandler<Observation.Context>> handler;
ObservationRegistryPostProcessor(ObjectProvider<ObservationHandler<Observation.Context>> handler) {
this.handler = handler;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ObservationRegistry registry) {
registry.observationConfig().observationHandler(this.handler.getObject());
}
return bean;
}
}
} }

View File

@ -23,12 +23,17 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -62,6 +67,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
@ -235,6 +241,25 @@ public class ReactiveMethodSecurityConfigurationTests {
verify(handler, never()).handleDeniedInvocation(any(), any(Authz.AuthzResult.class)); verify(handler, never()).handleDeniedInvocation(any(), any(Authz.AuthzResult.class));
} }
@Test
public void prePostMethodWhenObservationRegistryThenObserved() {
this.spring.register(MethodSecurityServiceConfig.class, ObservationRegistryConfig.class).autowire();
ReactiveMethodSecurityService service = this.spring.getContext().getBean(ReactiveMethodSecurityService.class);
Authentication user = TestAuthentication.authenticatedUser();
StepVerifier
.create(service.preAuthorizeUser().contextWrite(ReactiveSecurityContextHolder.withAuthentication(user)))
.expectNextCount(1)
.verifyComplete();
ObservationHandler<?> handler = this.spring.getContext().getBean(ObservationHandler.class);
verify(handler).onStart(any());
verify(handler).onStop(any());
StepVerifier
.create(service.preAuthorizeAdmin().contextWrite(ReactiveSecurityContextHolder.withAuthentication(user)))
.expectError()
.verify();
verify(handler).onError(any());
}
private static Consumer<User.UserBuilder> authorities(String... authorities) { private static Consumer<User.UserBuilder> authorities(String... authorities) {
return (builder) -> builder.authorities(authorities); return (builder) -> builder.authorities(authorities);
} }
@ -388,4 +413,30 @@ public class ReactiveMethodSecurityConfigurationTests {
} }
@Configuration
@EnableReactiveMethodSecurity
static class ObservationRegistryConfig {
private final ObservationRegistry registry = ObservationRegistry.create();
private final ObservationHandler<Observation.Context> handler = spy(new ObservationTextPublisher());
@Bean
ObservationRegistry observationRegistry() {
return this.registry;
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
PrePostMethodSecurityConfigurationTests.ObservationRegistryPostProcessor observationRegistryPostProcessor(
ObjectProvider<ObservationHandler<Observation.Context>> handler) {
return new PrePostMethodSecurityConfigurationTests.ObservationRegistryPostProcessor(handler);
}
}
} }

View File

@ -48,6 +48,12 @@ import org.springframework.util.StringUtils;
@ReactiveMethodSecurityService.Mask("classmask") @ReactiveMethodSecurityService.Mask("classmask")
public interface ReactiveMethodSecurityService { public interface ReactiveMethodSecurityService {
@PreAuthorize("hasRole('USER')")
Mono<String> preAuthorizeUser();
@PreAuthorize("hasRole('ADMIN')")
Mono<String> preAuthorizeAdmin();
@PreAuthorize("hasRole('ADMIN')") @PreAuthorize("hasRole('ADMIN')")
@HandleAuthorizationDenied(handlerClass = StarMaskingHandler.class) @HandleAuthorizationDenied(handlerClass = StarMaskingHandler.class)
Mono<String> preAuthorizeGetCardNumberIfAdmin(String cardNumber); Mono<String> preAuthorizeGetCardNumberIfAdmin(String cardNumber);

View File

@ -25,6 +25,16 @@ import org.springframework.security.authorization.AuthorizationDeniedException;
public class ReactiveMethodSecurityServiceImpl implements ReactiveMethodSecurityService { public class ReactiveMethodSecurityServiceImpl implements ReactiveMethodSecurityService {
@Override
public Mono<String> preAuthorizeUser() {
return Mono.just("user");
}
@Override
public Mono<String> preAuthorizeAdmin() {
return Mono.just("admin");
}
@Override @Override
public Mono<String> preAuthorizeGetCardNumberIfAdmin(String cardNumber) { public Mono<String> preAuthorizeGetCardNumberIfAdmin(String cardNumber) {
return Mono.just(cardNumber); return Mono.just(cardNumber);

View File

@ -18,12 +18,20 @@ package org.springframework.security.config.annotation.web.configurers;
import java.util.function.Supplier; import java.util.function.Supplier;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.access.hierarchicalroles.RoleHierarchy; import org.springframework.security.access.hierarchicalroles.RoleHierarchy;
@ -33,6 +41,7 @@ import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.AuthorizationDecision;
import org.springframework.security.authorization.AuthorizationEventPublisher; import org.springframework.security.authorization.AuthorizationEventPublisher;
import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.authorization.AuthorizationObservationContext;
import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry; import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
@ -43,6 +52,7 @@ import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.provisioning.InMemoryUserDetailsManager;
@ -63,8 +73,10 @@ import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector; import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -153,7 +165,8 @@ public class AuthorizeHttpRequestsConfigurerTests {
@Test @Test
public void configureWhenObjectPostProcessorRegisteredThenInvokedOnAuthorizationManagerAndAuthorizationFilter() { public void configureWhenObjectPostProcessorRegisteredThenInvokedOnAuthorizationManagerAndAuthorizationFilter() {
this.spring.register(ObjectPostProcessorConfig.class).autowire(); this.spring.register(ObjectPostProcessorConfig.class).autowire();
ObjectPostProcessor objectPostProcessor = this.spring.getContext().getBean(ObjectPostProcessor.class); ObjectPostProcessor<Object> objectPostProcessor = this.spring.getContext()
.getBean(ObjectPostProcessorConfig.class).objectPostProcessor;
verify(objectPostProcessor).postProcess(any(RequestMatcherDelegatingAuthorizationManager.class)); verify(objectPostProcessor).postProcess(any(RequestMatcherDelegatingAuthorizationManager.class));
verify(objectPostProcessor).postProcess(any(AuthorizationFilter.class)); verify(objectPostProcessor).postProcess(any(AuthorizationFilter.class));
} }
@ -623,6 +636,20 @@ public class AuthorizeHttpRequestsConfigurerTests {
this.mvc.perform(requestWithUser).andExpect(status().isOk()); this.mvc.perform(requestWithUser).andExpect(status().isOk());
} }
@Test
public void getWhenObservationRegistryThenObserves() throws Exception {
this.spring.register(RoleUserConfig.class, BasicController.class, ObservationRegistryConfig.class).autowire();
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
this.mvc.perform(get("/").with(user("user").roles("USER"))).andExpect(status().isOk());
ArgumentCaptor<Observation.Context> context = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, atLeastOnce()).onStart(context.capture());
assertThat(context.getAllValues()).anyMatch((c) -> c instanceof AuthorizationObservationContext);
verify(handler, atLeastOnce()).onStop(context.capture());
assertThat(context.getAllValues()).anyMatch((c) -> c instanceof AuthorizationObservationContext);
this.mvc.perform(get("/").with(user("user").roles("WRONG"))).andExpect(status().isForbidden());
verify(handler).onError(any());
}
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
static class GrantedAuthorityDefaultHasRoleConfig { static class GrantedAuthorityDefaultHasRoleConfig {
@ -1015,6 +1042,12 @@ public class AuthorizeHttpRequestsConfigurerTests {
// @formatter:on // @formatter:on
} }
@Bean
UserDetailsService users() {
return new InMemoryUserDetailsManager(
User.withUsername("user").password("{noop}password").roles("USER").build());
}
} }
@Configuration @Configuration
@ -1212,4 +1245,47 @@ public class AuthorizeHttpRequestsConfigurerTests {
} }
@Configuration
static class ObservationRegistryConfig {
private final ObservationRegistry registry = ObservationRegistry.create();
private final ObservationHandler<Observation.Context> handler = spy(new ObservationTextPublisher());
@Bean
ObservationRegistry observationRegistry() {
return this.registry;
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
ObservationRegistryPostProcessor observationRegistryPostProcessor(
ObjectProvider<ObservationHandler<Observation.Context>> handler) {
return new ObservationRegistryPostProcessor(handler);
}
}
static class ObservationRegistryPostProcessor implements BeanPostProcessor {
private final ObjectProvider<ObservationHandler<Observation.Context>> handler;
ObservationRegistryPostProcessor(ObjectProvider<ObservationHandler<Observation.Context>> handler) {
this.handler = handler;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ObservationRegistry registry) {
registry.observationConfig().observationHandler(this.handler.getObject());
}
return bean;
}
}
} }

View File

@ -16,14 +16,23 @@
package org.springframework.security.config.annotation.web.configurers; package org.springframework.security.config.annotation.web.configurers;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.AuthenticationObservationContext;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.config.Customizer; import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.ObjectPostProcessor;
@ -50,7 +59,9 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -161,6 +172,23 @@ public class HttpBasicConfigurerTests {
.saveContext(any(SecurityContext.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); .saveContext(any(SecurityContext.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test
public void httpBasicWhenObservationRegistryThenObserves() throws Exception {
this.spring.register(HttpBasic.class, Users.class, Home.class, ObservationRegistryConfig.class).autowire();
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
this.mvc.perform(get("/").with(httpBasic("user", "password")))
.andExpect(status().isOk())
.andExpect(content().string("user"));
ArgumentCaptor<Observation.Context> context = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, atLeastOnce()).onStart(context.capture());
assertThat(context.getAllValues()).anyMatch((c) -> c instanceof AuthenticationObservationContext);
verify(handler, atLeastOnce()).onStop(context.capture());
assertThat(context.getAllValues()).anyMatch((c) -> c instanceof AuthenticationObservationContext);
this.mvc.perform(get("/").with(httpBasic("user", "wrong"))).andExpect(status().isUnauthorized());
verify(handler).onError(context.capture());
assertThat(context.getValue()).isInstanceOf(AuthenticationObservationContext.class);
}
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
static class ObjectPostProcessorConfig { static class ObjectPostProcessorConfig {
@ -384,4 +412,47 @@ public class HttpBasicConfigurerTests {
} }
@Configuration
static class ObservationRegistryConfig {
private final ObservationRegistry registry = ObservationRegistry.create();
private final ObservationHandler<Observation.Context> handler = spy(new ObservationTextPublisher());
@Bean
ObservationRegistry observationRegistry() {
return this.registry;
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
ObservationRegistryPostProcessor observationRegistryPostProcessor(
ObjectProvider<ObservationHandler<Observation.Context>> handler) {
return new ObservationRegistryPostProcessor(handler);
}
}
static class ObservationRegistryPostProcessor implements BeanPostProcessor {
private final ObjectProvider<ObservationHandler<Observation.Context>> handler;
ObservationRegistryPostProcessor(ObjectProvider<ObservationHandler<Observation.Context>> handler) {
this.handler = handler;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ObservationRegistry registry) {
registry.observationConfig().observationHandler(this.handler.getObject());
}
return bean;
}
}
} }

View File

@ -26,13 +26,20 @@ import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
@ -95,7 +102,9 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
@ -381,6 +390,30 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
clientInboundChannel().send(message("/anonymous")); clientInboundChannel().send(message("/anonymous"));
} }
@Test
public void sendMessageWhenObservationRegistryThenObserves() {
loadConfig(WebSocketSecurityConfig.class, ObservationRegistryConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
Message<?> message = message(headers, "/authenticated");
headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
clientInboundChannel().send(message);
ObservationHandler<Observation.Context> observationHandler = this.context.getBean(ObservationHandler.class);
verify(observationHandler).onStart(any());
verify(observationHandler).onStop(any());
headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
message = message(headers, "/denyAll");
headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
try {
clientInboundChannel().send(message);
}
catch (MessageDeliveryException ex) {
// okay
}
verify(observationHandler).onError(any());
}
private void assertHandshake(HttpServletRequest request) { private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
@ -892,4 +925,47 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
} }
@Configuration
static class ObservationRegistryConfig {
private final ObservationRegistry registry = ObservationRegistry.create();
private final ObservationHandler<Observation.Context> handler = spy(new ObservationTextPublisher());
@Bean
ObservationRegistry observationRegistry() {
return this.registry;
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
ObservationRegistryPostProcessor observationRegistryPostProcessor(
ObjectProvider<ObservationHandler<Observation.Context>> handler) {
return new ObservationRegistryPostProcessor(handler);
}
}
static class ObservationRegistryPostProcessor implements BeanPostProcessor {
private final ObjectProvider<ObservationHandler<Observation.Context>> handler;
ObservationRegistryPostProcessor(ObjectProvider<ObservationHandler<Observation.Context>> handler) {
this.handler = handler;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ObservationRegistry registry) {
registry.observationConfig().observationHandler(this.handler.getObject());
}
return bean;
}
}
} }