diff --git a/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java index bb6c6d2c34..0c709b50e2 100644 --- a/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java @@ -177,17 +177,19 @@ public final class ObservationFilterChainDecorator implements FilterChainProxy.F private void wrapFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { AroundFilterObservation parent = observation((HttpServletRequest) request); - FilterChainObservationContext parentBefore = (FilterChainObservationContext) parent.before().getContext(); - parentBefore.setChainSize(this.size); - parentBefore.setFilterName(this.name); - parentBefore.setChainPosition(this.position); + if (parent.before().getContext() instanceof FilterChainObservationContext parentBefore) { + parentBefore.setChainSize(this.size); + parentBefore.setFilterName(this.name); + parentBefore.setChainPosition(this.position); + } parent.before().event(Observation.Event.of(this.name + " before")); this.filter.doFilter(request, response, chain); parent.start(); - FilterChainObservationContext parentAfter = (FilterChainObservationContext) parent.after().getContext(); - parentAfter.setChainSize(this.size); - parentAfter.setFilterName(this.name); - parentAfter.setChainPosition(this.size - this.position + 1); + if (parent.after().getContext() instanceof FilterChainObservationContext parentAfter) { + parentAfter.setChainSize(this.size); + parentAfter.setFilterName(this.name); + parentAfter.setChainPosition(this.size - this.position + 1); + } parent.after().event(Observation.Event.of(this.name + " after")); } diff --git a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java index efd4482a9c..11ecf33ac9 100644 --- a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java @@ -196,18 +196,18 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP private Mono wrapFilter(ServerWebExchange exchange, WebFilterChain chain) { AroundWebFilterObservation parent = observation(exchange); - WebFilterChainObservationContext parentBefore = (WebFilterChainObservationContext) parent.before() - .getContext(); - parentBefore.setChainSize(this.size); - parentBefore.setFilterName(this.name); - parentBefore.setChainPosition(this.position); + if (parent.before().getContext() instanceof WebFilterChainObservationContext parentBefore) { + parentBefore.setChainSize(this.size); + parentBefore.setFilterName(this.name); + parentBefore.setChainPosition(this.position); + } return this.filter.filter(exchange, chain).doOnSuccess((result) -> { parent.start(); - WebFilterChainObservationContext parentAfter = (WebFilterChainObservationContext) parent.after() - .getContext(); - parentAfter.setChainSize(this.size); - parentAfter.setFilterName(this.name); - parentAfter.setChainPosition(this.size - this.position + 1); + if (parent.after().getContext() instanceof WebFilterChainObservationContext parentAfter) { + parentAfter.setChainSize(this.size); + parentAfter.setFilterName(this.name); + parentAfter.setChainPosition(this.size - this.position + 1); + } }); } diff --git a/web/src/test/java/org/springframework/security/web/ObservationFilterChainDecoratorTests.java b/web/src/test/java/org/springframework/security/web/ObservationFilterChainDecoratorTests.java new file mode 100644 index 0000000000..caf7c71ba1 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/ObservationFilterChainDecoratorTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web; + +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import jakarta.servlet.FilterChain; +import org.junit.jupiter.api.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link ObservationFilterChainDecorator} + */ +public class ObservationFilterChainDecoratorTests { + + @Test + void decorateWhenDefaultsThenObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationFilterChainDecorator decorator = new ObservationFilterChainDecorator(registry); + FilterChain chain = mock(FilterChain.class); + FilterChain decorated = decorator.decorate(chain); + decorated.doFilter(new MockHttpServletRequest("GET", "/"), new MockHttpServletResponse()); + verify(handler).onStart(any()); + } + + @Test + void decorateWhenNoopThenDoesNotObserve() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.NOOP; + registry.observationConfig().observationHandler(handler); + ObservationFilterChainDecorator decorator = new ObservationFilterChainDecorator(registry); + FilterChain chain = mock(FilterChain.class); + FilterChain decorated = decorator.decorate(chain); + decorated.doFilter(new MockHttpServletRequest("GET", "/"), new MockHttpServletResponse()); + verifyNoInteractions(handler); + } + +} diff --git a/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java new file mode 100644 index 0000000000..08aba40d9f --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server; + +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.WebFilterChain; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link ObservationWebFilterChainDecorator} + */ +public class ObservationWebFilterChainDecoratorTests { + + @Test + void decorateWhenDefaultsThenObserves() { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilterChain chain = mock(WebFilterChain.class); + given(chain.filter(any())).willReturn(Mono.empty()); + WebFilterChain decorated = decorator.decorate(chain); + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())).block(); + verify(handler).onStart(any()); + } + + @Test + void decorateWhenNoopThenDoesNotObserve() { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.NOOP; + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilterChain chain = mock(WebFilterChain.class); + given(chain.filter(any())).willReturn(Mono.empty()); + WebFilterChain decorated = decorator.decorate(chain); + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())).block(); + verifyNoInteractions(handler); + } + +}