diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index a993d53d01..8926888bdb 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -25,6 +25,7 @@ import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.context.ServerSecurityContextRepository; @@ -33,11 +34,12 @@ import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; import static org.mockito.Matchers.any; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication; /** @@ -61,6 +63,8 @@ public class ServerHttpSecurityTests { @Test public void defaults() { + TestPublisher securityContext = TestPublisher.create(); + when(this.contextRepository.load(any())).thenReturn(securityContext.mono()); this.http.securityContextRepository(this.contextRepository); WebTestClient client = buildClient(); @@ -73,7 +77,7 @@ public class ServerHttpSecurityTests { assertThat(result.getResponseCookies()).isEmpty(); // there is no need to try and load the SecurityContext by default - verifyZeroInteractions(this.contextRepository); + securityContext.assertWasNotSubscribed(); } @Test diff --git a/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java b/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java index 7db3bdf4b2..e4e95078b7 100644 --- a/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java @@ -45,7 +45,7 @@ public class ReactorContextWebFilter implements WebFilter { } private Context withSecurityContext(Context mainContext, ServerWebExchange exchange) { - return mainContext.putAll(Mono.defer(() -> this.repository.load(exchange)) + return mainContext.putAll(this.repository.load(exchange) .as(ReactiveSecurityContextHolder::withSecurityContext)); } } diff --git a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java index f3472a9b7f..ccf472e2f2 100644 --- a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java @@ -33,6 +33,7 @@ import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.handler.DefaultWebFilterChain; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; import reactor.util.context.Context; import static org.assertj.core.api.Assertions.assertThat; @@ -53,6 +54,8 @@ public class ReactorContextWebFilterTests { private MockServerHttpRequest.BaseBuilder exchange = MockServerHttpRequest.get("/"); + private TestPublisher securityContext = TestPublisher.create(); + private ReactorContextWebFilter filter; private WebTestHandler handler; @@ -62,6 +65,7 @@ public class ReactorContextWebFilterTests { public void setup() { this.filter = new ReactorContextWebFilter(this.repository); this.handler = WebTestHandler.bindToWebFilters(this.filter); + when(this.repository.load(any())).thenReturn(this.securityContext.mono()); } @Test(expected = IllegalArgumentException.class) @@ -74,7 +78,7 @@ public class ReactorContextWebFilterTests { public void filterWhenNoPrincipalAccessThenNoInteractions() { this.handler.exchange(this.exchange); - verifyZeroInteractions(this.repository); + this.securityContext.assertWasNotSubscribed(); } @Test @@ -86,7 +90,7 @@ public class ReactorContextWebFilterTests { this.handler.exchange(this.exchange); - verifyZeroInteractions(this.repository); + this.securityContext.assertWasNotSubscribed(); } @Test @@ -102,7 +106,7 @@ public class ReactorContextWebFilterTests { WebTestHandler.WebHandlerResult result = this.handler.exchange(this.exchange); - verify(this.repository).load(any()); + this.securityContext.assertWasNotSubscribed(); } @Test