diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java index 104cb0b766..dae657b36e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,20 @@ package org.springframework.security.config.annotation.web.configuration; import java.net.URI; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; import org.junit.jupiter.api.extension.ExtendWith; import reactor.core.CoreSubscriber; import reactor.core.publisher.BaseSubscriber; @@ -35,6 +41,8 @@ import reactor.util.context.Context; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.VirtualThreadTaskExecutor; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockHttpServletRequest; @@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; @@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests { verify(strategy, times(2)).getContext(); } + @Test + public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception { + this.spring.register(SecurityConfig.class).autowire(); + + ThreadFactory threadFactory = Executors.defaultThreadFactory(); + assertContextAttributesAvailable(threadFactory); + } + + @Test + @DisabledOnJre(JRE.JAVA_17) + public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception { + this.spring.register(SecurityConfig.class).autowire(); + + ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory(); + assertContextAttributesAvailable(threadFactory); + } + + private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception { + Map expectedContextAttributes = new HashMap<>(); + expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest); + expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse); + expectedContextAttributes.put(Authentication.class, this.authentication); + + try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) { + Future> future = taskExecutor.submit(this::propagateRequestAttributes); + assertThat(future.get()).isEqualTo(expectedContextAttributes); + } + } + + private Map propagateRequestAttributes() { + RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse); + RequestContextHolder.setRequestAttributes(requestAttributes); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(this.authentication); + SecurityContextHolder.setContext(securityContext); + + // @formatter:off + return Mono.deferContextual(Mono::just) + .filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .map((ctx) -> ctx.>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .map((attributes) -> { + Map map = new HashMap<>(); + // Copy over items from lazily loaded map + Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class) + .forEach((key) -> map.put(key, attributes.get(key))); + return map; + }) + .block(); + // @formatter:on + } + @Configuration @EnableWebSecurity static class SecurityConfig { diff --git a/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java index df8b7eef09..12b997de5f 100644 --- a/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java +++ b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,17 @@ package org.springframework.security.core.context; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; + import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import org.springframework.core.task.VirtualThreadTaskExecutor; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -99,4 +106,53 @@ public class ReactiveSecurityContextHolderTests { // @formatter:on } + @Test + public void getContextWhenThreadFactoryIsPlatformThenPropagated() { + verifySecurityContextIsPropagated(Executors.defaultThreadFactory()); + } + + @Test + @DisabledOnJre(JRE.JAVA_17) + public void getContextWhenThreadFactoryIsVirtualThenPropagated() { + verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory()); + } + + private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) { + Authentication authentication = new TestingAuthenticationToken("user", null); + + // @formatter:off + Mono publisher = ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscribeOn(Schedulers.newSingle(threadFactory)); + // @formatter:on + + StepVerifier.create(publisher).expectNext(authentication).verifyComplete(); + } + + @Test + public void clearContextWhenThreadFactoryIsPlatformThenCleared() { + verifySecurityContextIsCleared(Executors.defaultThreadFactory()); + } + + @Test + @DisabledOnJre(JRE.JAVA_17) + public void clearContextWhenThreadFactoryIsVirtualThenCleared() { + verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory()); + } + + private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) { + Authentication authentication = new TestingAuthenticationToken("user", null); + + // @formatter:off + Mono publisher = ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .contextWrite(ReactiveSecurityContextHolder.clearContext()) + .contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscribeOn(Schedulers.newSingle(threadFactory)); + // @formatter:on + + StepVerifier.create(publisher).verifyComplete(); + } + } diff --git a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java index 73273b8292..5213f9c75b 100644 --- a/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,17 +17,23 @@ package org.springframework.security.web.server.context; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.util.context.Context; +import org.springframework.core.task.VirtualThreadTaskExecutor; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.core.Authentication; @@ -117,4 +123,32 @@ public class ReactorContextWebFilterTests { StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete(); } + @Test + public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() { + ThreadFactory threadFactory = Executors.defaultThreadFactory(); + assertSecurityContextLoaded(threadFactory); + } + + @Test + @DisabledOnJre(JRE.JAVA_17) + public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() { + ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory(); + assertSecurityContextLoaded(threadFactory); + } + + private void assertSecurityContextLoaded(ThreadFactory threadFactory) { + SecurityContextImpl context = new SecurityContextImpl(this.principal); + given(this.repository.load(any())).willReturn(Mono.just(context)); + // @formatter:off + WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange) + .subscribeOn(Schedulers.newSingle(threadFactory)); + WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal)) + .then(chain.filter(exchange)); + // @formatter:on + this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext); + this.handler.exchange(this.exchange); + } + } diff --git a/web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java index caf12af6e4..6203287628 100644 --- a/web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,17 +17,25 @@ package org.springframework.security.web.server.context; import java.util.Collections; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import org.springframework.core.task.VirtualThreadTaskExecutor; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.test.web.reactive.server.WebTestHandler; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; import org.springframework.web.server.handler.DefaultWebFilterChain; import static org.assertj.core.api.Assertions.assertThat; @@ -80,4 +88,31 @@ public class SecurityContextServerWebExchangeWebFilterTests { StepVerifier.create(result).verifyComplete(); } + @Test + public void filterWhenThreadFactoryIsPlatformThenContextPopulated() { + ThreadFactory threadFactory = Executors.defaultThreadFactory(); + assertPrincipalPopulated(threadFactory); + } + + @Test + @DisabledOnJre(JRE.JAVA_17) + public void filterWhenThreadFactoryIsVirtualThenContextPopulated() { + ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory(); + assertPrincipalPopulated(threadFactory); + } + + private void assertPrincipalPopulated(ThreadFactory threadFactory) { + // @formatter:off + WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal)) + .subscribeOn(Schedulers.newSingle(threadFactory)); + WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal() + .doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal)) + .then(chain.filter(exchange)); + // @formatter:on + WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, + assertPrincipal); + handler.exchange(this.exchange); + } + }