Verify ReactorContext when using Virtual Threads

Closes gh-12791
This commit is contained in:
Steve Riesenberg 2023-09-21 10:27:25 -05:00
parent d6fac11bfe
commit ff374935fb
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
4 changed files with 190 additions and 4 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,14 +17,20 @@
package org.springframework.security.config.annotation.web.configuration; package org.springframework.security.config.annotation.web.configuration;
import java.net.URI; import java.net.URI;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import reactor.core.CoreSubscriber; import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.BaseSubscriber;
@ -35,6 +41,8 @@ import reactor.util.context.Context;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
verify(strategy, times(2)).getContext(); verify(strategy, times(2)).getContext();
} }
@Test
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
this.spring.register(SecurityConfig.class).autowire();
ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertContextAttributesAvailable(threadFactory);
}
@Test
@DisabledOnJre(JRE.JAVA_17)
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
this.spring.register(SecurityConfig.class).autowire();
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertContextAttributesAvailable(threadFactory);
}
private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
Map<Object, Object> expectedContextAttributes = new HashMap<>();
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
expectedContextAttributes.put(Authentication.class, this.authentication);
try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
assertThat(future.get()).isEqualTo(expectedContextAttributes);
}
}
private Map<Object, Object> propagateRequestAttributes() {
RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
RequestContextHolder.setRequestAttributes(requestAttributes);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(this.authentication);
SecurityContextHolder.setContext(securityContext);
// @formatter:off
return Mono.deferContextual(Mono::just)
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((attributes) -> {
Map<Object, Object> map = new HashMap<>();
// Copy over items from lazily loaded map
Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
.forEach((key) -> map.put(key, attributes.get(key)));
return map;
})
.block();
// @formatter:on
}
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
static class SecurityConfig { static class SecurityConfig {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,17 @@
package org.springframework.security.core.context; package org.springframework.security.core.context;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -99,4 +106,53 @@ public class ReactiveSecurityContextHolderTests {
// @formatter:on // @formatter:on
} }
@Test
public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
}
@Test
@DisabledOnJre(JRE.JAVA_17)
public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
}
private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
Authentication authentication = new TestingAuthenticationToken("user", null);
// @formatter:off
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscribeOn(Schedulers.newSingle(threadFactory));
// @formatter:on
StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
}
@Test
public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
verifySecurityContextIsCleared(Executors.defaultThreadFactory());
}
@Test
@DisabledOnJre(JRE.JAVA_17)
public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
}
private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
Authentication authentication = new TestingAuthenticationToken("user", null);
// @formatter:off
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.contextWrite(ReactiveSecurityContextHolder.clearContext())
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscribeOn(Schedulers.newSingle(threadFactory));
// @formatter:on
StepVerifier.create(publisher).verifyComplete();
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,17 +17,23 @@
package org.springframework.security.web.server.context; package org.springframework.security.web.server.context;
import java.util.List; import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher; import reactor.test.publisher.TestPublisher;
import reactor.util.context.Context; import reactor.util.context.Context;
import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -117,4 +123,32 @@ public class ReactorContextWebFilterTests {
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete(); StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
} }
@Test
public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertSecurityContextLoaded(threadFactory);
}
@Test
@DisabledOnJre(JRE.JAVA_17)
public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertSecurityContextLoaded(threadFactory);
}
private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
SecurityContextImpl context = new SecurityContextImpl(this.principal);
given(this.repository.load(any())).willReturn(Mono.just(context));
// @formatter:off
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
.subscribeOn(Schedulers.newSingle(threadFactory));
WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
.then(chain.filter(exchange));
// @formatter:on
this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
this.handler.exchange(this.exchange);
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,17 +17,25 @@
package org.springframework.security.web.server.context; package org.springframework.security.web.server.context;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.test.web.reactive.server.WebTestHandler;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.handler.DefaultWebFilterChain; import org.springframework.web.server.handler.DefaultWebFilterChain;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -80,4 +88,31 @@ public class SecurityContextServerWebExchangeWebFilterTests {
StepVerifier.create(result).verifyComplete(); StepVerifier.create(result).verifyComplete();
} }
@Test
public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertPrincipalPopulated(threadFactory);
}
@Test
@DisabledOnJre(JRE.JAVA_17)
public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertPrincipalPopulated(threadFactory);
}
private void assertPrincipalPopulated(ThreadFactory threadFactory) {
// @formatter:off
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
.subscribeOn(Schedulers.newSingle(threadFactory));
WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
.then(chain.filter(exchange));
// @formatter:on
WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
assertPrincipal);
handler.exchange(this.exchange);
}
} }