Verify ReactorContext when using Virtual Threads
Closes gh-12791
This commit is contained in:
parent
d6fac11bfe
commit
ff374935fb
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,14 +17,20 @@
|
|||
package org.springframework.security.config.annotation.web.configuration;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||
import org.junit.jupiter.api.condition.JRE;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.BaseSubscriber;
|
||||
|
@ -35,6 +41,8 @@ import reactor.util.context.Context;
|
|||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
|
@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
|
|||
import org.springframework.security.config.test.SpringTestContext;
|
||||
import org.springframework.security.config.test.SpringTestContextExtension;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
||||
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
||||
|
@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
|
|||
verify(strategy, times(2)).getContext();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
|
||||
this.spring.register(SecurityConfig.class).autowire();
|
||||
|
||||
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||
assertContextAttributesAvailable(threadFactory);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisabledOnJre(JRE.JAVA_17)
|
||||
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
|
||||
this.spring.register(SecurityConfig.class).autowire();
|
||||
|
||||
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||
assertContextAttributesAvailable(threadFactory);
|
||||
}
|
||||
|
||||
private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
|
||||
Map<Object, Object> expectedContextAttributes = new HashMap<>();
|
||||
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
|
||||
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
|
||||
expectedContextAttributes.put(Authentication.class, this.authentication);
|
||||
|
||||
try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
|
||||
Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
|
||||
assertThat(future.get()).isEqualTo(expectedContextAttributes);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<Object, Object> propagateRequestAttributes() {
|
||||
RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
|
||||
RequestContextHolder.setRequestAttributes(requestAttributes);
|
||||
|
||||
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||
securityContext.setAuthentication(this.authentication);
|
||||
SecurityContextHolder.setContext(securityContext);
|
||||
|
||||
// @formatter:off
|
||||
return Mono.deferContextual(Mono::just)
|
||||
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
||||
.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
||||
.map((attributes) -> {
|
||||
Map<Object, Object> map = new HashMap<>();
|
||||
// Copy over items from lazily loaded map
|
||||
Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
|
||||
.forEach((key) -> map.put(key, attributes.get(key)));
|
||||
return map;
|
||||
})
|
||||
.block();
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@Configuration
|
||||
@EnableWebSecurity
|
||||
static class SecurityConfig {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,10 +16,17 @@
|
|||
|
||||
package org.springframework.security.core.context;
|
||||
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||
import org.junit.jupiter.api.condition.JRE;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import reactor.test.StepVerifier;
|
||||
|
||||
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
|
||||
|
@ -99,4 +106,53 @@ public class ReactiveSecurityContextHolderTests {
|
|||
// @formatter:on
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
|
||||
verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisabledOnJre(JRE.JAVA_17)
|
||||
public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
|
||||
verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
|
||||
}
|
||||
|
||||
private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
|
||||
Authentication authentication = new TestingAuthenticationToken("user", null);
|
||||
|
||||
// @formatter:off
|
||||
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
|
||||
.map(SecurityContext::getAuthentication)
|
||||
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||
// @formatter:on
|
||||
|
||||
StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
|
||||
verifySecurityContextIsCleared(Executors.defaultThreadFactory());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisabledOnJre(JRE.JAVA_17)
|
||||
public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
|
||||
verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
|
||||
}
|
||||
|
||||
private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
|
||||
Authentication authentication = new TestingAuthenticationToken("user", null);
|
||||
|
||||
// @formatter:off
|
||||
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
|
||||
.map(SecurityContext::getAuthentication)
|
||||
.contextWrite(ReactiveSecurityContextHolder.clearContext())
|
||||
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||
// @formatter:on
|
||||
|
||||
StepVerifier.create(publisher).verifyComplete();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,17 +17,23 @@
|
|||
package org.springframework.security.web.server.context;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||
import org.junit.jupiter.api.condition.JRE;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.test.publisher.TestPublisher;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||
import org.springframework.security.core.Authentication;
|
||||
|
@ -117,4 +123,32 @@ public class ReactorContextWebFilterTests {
|
|||
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
|
||||
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||
assertSecurityContextLoaded(threadFactory);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisabledOnJre(JRE.JAVA_17)
|
||||
public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
|
||||
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||
assertSecurityContextLoaded(threadFactory);
|
||||
}
|
||||
|
||||
private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
|
||||
SecurityContextImpl context = new SecurityContextImpl(this.principal);
|
||||
given(this.repository.load(any())).willReturn(Mono.just(context));
|
||||
// @formatter:off
|
||||
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
|
||||
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||
WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
|
||||
.map(SecurityContext::getAuthentication)
|
||||
.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
|
||||
.then(chain.filter(exchange));
|
||||
// @formatter:on
|
||||
this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
|
||||
this.handler.exchange(this.exchange);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,17 +17,25 @@
|
|||
package org.springframework.security.web.server.context;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||
import org.junit.jupiter.api.condition.JRE;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import reactor.test.StepVerifier;
|
||||
|
||||
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
import org.springframework.security.test.web.reactive.server.WebTestHandler;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.handler.DefaultWebFilterChain;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -80,4 +88,31 @@ public class SecurityContextServerWebExchangeWebFilterTests {
|
|||
StepVerifier.create(result).verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
|
||||
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||
assertPrincipalPopulated(threadFactory);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisabledOnJre(JRE.JAVA_17)
|
||||
public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
|
||||
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||
assertPrincipalPopulated(threadFactory);
|
||||
}
|
||||
|
||||
private void assertPrincipalPopulated(ThreadFactory threadFactory) {
|
||||
// @formatter:off
|
||||
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
|
||||
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
|
||||
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||
WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
|
||||
.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
|
||||
.then(chain.filter(exchange));
|
||||
// @formatter:on
|
||||
WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
|
||||
assertPrincipal);
|
||||
handler.exchange(this.exchange);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue