Verify ReactorContext when using Virtual Threads
Closes gh-12791
This commit is contained in:
parent
d6fac11bfe
commit
ff374935fb
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2022 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,14 +17,20 @@
|
||||||
package org.springframework.security.config.annotation.web.configuration;
|
package org.springframework.security.config.annotation.web.configuration;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.Future;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
|
||||||
import jakarta.servlet.http.HttpServletRequest;
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
import jakarta.servlet.http.HttpServletResponse;
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||||
|
import org.junit.jupiter.api.condition.JRE;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
import reactor.core.CoreSubscriber;
|
import reactor.core.CoreSubscriber;
|
||||||
import reactor.core.publisher.BaseSubscriber;
|
import reactor.core.publisher.BaseSubscriber;
|
||||||
|
@ -35,6 +41,8 @@ import reactor.util.context.Context;
|
||||||
|
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||||
|
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
|
@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
|
||||||
import org.springframework.security.config.test.SpringTestContext;
|
import org.springframework.security.config.test.SpringTestContext;
|
||||||
import org.springframework.security.config.test.SpringTestContextExtension;
|
import org.springframework.security.config.test.SpringTestContextExtension;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.context.SecurityContext;
|
||||||
import org.springframework.security.core.context.SecurityContextHolder;
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
||||||
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
||||||
|
@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
|
||||||
verify(strategy, times(2)).getContext();
|
verify(strategy, times(2)).getContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
|
||||||
|
this.spring.register(SecurityConfig.class).autowire();
|
||||||
|
|
||||||
|
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||||
|
assertContextAttributesAvailable(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisabledOnJre(JRE.JAVA_17)
|
||||||
|
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
|
||||||
|
this.spring.register(SecurityConfig.class).autowire();
|
||||||
|
|
||||||
|
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||||
|
assertContextAttributesAvailable(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
|
||||||
|
Map<Object, Object> expectedContextAttributes = new HashMap<>();
|
||||||
|
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
|
||||||
|
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
|
||||||
|
expectedContextAttributes.put(Authentication.class, this.authentication);
|
||||||
|
|
||||||
|
try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
|
||||||
|
Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
|
||||||
|
assertThat(future.get()).isEqualTo(expectedContextAttributes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<Object, Object> propagateRequestAttributes() {
|
||||||
|
RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
|
||||||
|
RequestContextHolder.setRequestAttributes(requestAttributes);
|
||||||
|
|
||||||
|
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||||
|
securityContext.setAuthentication(this.authentication);
|
||||||
|
SecurityContextHolder.setContext(securityContext);
|
||||||
|
|
||||||
|
// @formatter:off
|
||||||
|
return Mono.deferContextual(Mono::just)
|
||||||
|
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
||||||
|
.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
||||||
|
.map((attributes) -> {
|
||||||
|
Map<Object, Object> map = new HashMap<>();
|
||||||
|
// Copy over items from lazily loaded map
|
||||||
|
Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
|
||||||
|
.forEach((key) -> map.put(key, attributes.get(key)));
|
||||||
|
return map;
|
||||||
|
})
|
||||||
|
.block();
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
static class SecurityConfig {
|
static class SecurityConfig {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2017 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,10 +16,17 @@
|
||||||
|
|
||||||
package org.springframework.security.core.context;
|
package org.springframework.security.core.context;
|
||||||
|
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||||
|
import org.junit.jupiter.api.condition.JRE;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
import reactor.core.scheduler.Schedulers;
|
||||||
import reactor.test.StepVerifier;
|
import reactor.test.StepVerifier;
|
||||||
|
|
||||||
|
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
|
|
||||||
|
@ -99,4 +106,53 @@ public class ReactiveSecurityContextHolderTests {
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
|
||||||
|
verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisabledOnJre(JRE.JAVA_17)
|
||||||
|
public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
|
||||||
|
verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
|
||||||
|
Authentication authentication = new TestingAuthenticationToken("user", null);
|
||||||
|
|
||||||
|
// @formatter:off
|
||||||
|
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
|
||||||
|
.map(SecurityContext::getAuthentication)
|
||||||
|
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||||
|
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
|
StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
|
||||||
|
verifySecurityContextIsCleared(Executors.defaultThreadFactory());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisabledOnJre(JRE.JAVA_17)
|
||||||
|
public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
|
||||||
|
verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
|
||||||
|
Authentication authentication = new TestingAuthenticationToken("user", null);
|
||||||
|
|
||||||
|
// @formatter:off
|
||||||
|
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
|
||||||
|
.map(SecurityContext::getAuthentication)
|
||||||
|
.contextWrite(ReactiveSecurityContextHolder.clearContext())
|
||||||
|
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||||
|
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
|
StepVerifier.create(publisher).verifyComplete();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2017 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,17 +17,23 @@
|
||||||
package org.springframework.security.web.server.context;
|
package org.springframework.security.web.server.context;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||||
|
import org.junit.jupiter.api.condition.JRE;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
import reactor.core.scheduler.Schedulers;
|
||||||
import reactor.test.StepVerifier;
|
import reactor.test.StepVerifier;
|
||||||
import reactor.test.publisher.TestPublisher;
|
import reactor.test.publisher.TestPublisher;
|
||||||
import reactor.util.context.Context;
|
import reactor.util.context.Context;
|
||||||
|
|
||||||
|
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
|
@ -117,4 +123,32 @@ public class ReactorContextWebFilterTests {
|
||||||
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
|
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
|
||||||
|
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||||
|
assertSecurityContextLoaded(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisabledOnJre(JRE.JAVA_17)
|
||||||
|
public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
|
||||||
|
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||||
|
assertSecurityContextLoaded(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
|
||||||
|
SecurityContextImpl context = new SecurityContextImpl(this.principal);
|
||||||
|
given(this.repository.load(any())).willReturn(Mono.just(context));
|
||||||
|
// @formatter:off
|
||||||
|
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
|
||||||
|
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||||
|
WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
|
||||||
|
.map(SecurityContext::getAuthentication)
|
||||||
|
.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
|
||||||
|
.then(chain.filter(exchange));
|
||||||
|
// @formatter:on
|
||||||
|
this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
|
||||||
|
this.handler.exchange(this.exchange);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2017 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,17 +17,25 @@
|
||||||
package org.springframework.security.web.server.context;
|
package org.springframework.security.web.server.context;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.condition.DisabledOnJre;
|
||||||
|
import org.junit.jupiter.api.condition.JRE;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
import reactor.core.scheduler.Schedulers;
|
||||||
import reactor.test.StepVerifier;
|
import reactor.test.StepVerifier;
|
||||||
|
|
||||||
|
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||||
|
import org.springframework.security.test.web.reactive.server.WebTestHandler;
|
||||||
import org.springframework.web.server.ServerWebExchange;
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
|
import org.springframework.web.server.WebFilter;
|
||||||
import org.springframework.web.server.handler.DefaultWebFilterChain;
|
import org.springframework.web.server.handler.DefaultWebFilterChain;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -80,4 +88,31 @@ public class SecurityContextServerWebExchangeWebFilterTests {
|
||||||
StepVerifier.create(result).verifyComplete();
|
StepVerifier.create(result).verifyComplete();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
|
||||||
|
ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
||||||
|
assertPrincipalPopulated(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisabledOnJre(JRE.JAVA_17)
|
||||||
|
public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
|
||||||
|
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
||||||
|
assertPrincipalPopulated(threadFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertPrincipalPopulated(ThreadFactory threadFactory) {
|
||||||
|
// @formatter:off
|
||||||
|
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
|
||||||
|
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
|
||||||
|
.subscribeOn(Schedulers.newSingle(threadFactory));
|
||||||
|
WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
|
||||||
|
.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
|
||||||
|
.then(chain.filter(exchange));
|
||||||
|
// @formatter:on
|
||||||
|
WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
|
||||||
|
assertPrincipal);
|
||||||
|
handler.exchange(this.exchange);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue