Pick up SecurityContextHolderStrategy for WebClient integration

Issue gh-11061
This commit is contained in:
Josh Cummings 2022-06-21 17:13:31 -06:00
parent a218d3e140
commit d24a89ad53
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
3 changed files with 90 additions and 14 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,10 +36,13 @@ import reactor.util.context.Context;
import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
@ -61,24 +64,37 @@ import org.springframework.web.context.request.ServletRequestAttributes;
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
class SecurityReactorContextConfiguration { class SecurityReactorContextConfiguration {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Bean @Bean
SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() { SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
return new SecurityReactorContextSubscriberRegistrar(); SecurityReactorContextSubscriberRegistrar registrar = new SecurityReactorContextSubscriberRegistrar();
registrar.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
return registrar;
}
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
} }
static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean { static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR"; private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>(); private final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
static { private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class, .getContextHolderStrategy();
SecurityReactorContextSubscriberRegistrar() {
this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
SecurityReactorContextSubscriberRegistrar::getRequest); SecurityReactorContextSubscriberRegistrar::getRequest);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class, this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
SecurityReactorContextSubscriberRegistrar::getResponse); SecurityReactorContextSubscriberRegistrar::getResponse);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, this::getAuthentication);
SecurityReactorContextSubscriberRegistrar::getAuthentication);
} }
@Override @Override
@ -93,6 +109,11 @@ class SecurityReactorContextConfiguration {
Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY); Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
} }
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) { <T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) { if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) {
// Already enriched. No need to create Subscriber so return original // Already enriched. No need to create Subscriber so return original
@ -101,8 +122,8 @@ class SecurityReactorContextConfiguration {
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes()); return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
} }
private static Map<Object, Object> getContextAttributes() { private Map<Object, Object> getContextAttributes() {
return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS); return new LoadingMap<>(this.CONTEXT_ATTRIBUTE_VALUE_LOADERS);
} }
private static HttpServletRequest getRequest() { private static HttpServletRequest getRequest() {
@ -123,8 +144,8 @@ class SecurityReactorContextConfiguration {
return null; return null;
} }
private static Authentication getAuthentication() { private Authentication getAuthentication() {
return SecurityContextHolder.getContext().getAuthentication(); return this.securityContextHolderStrategy.getContext().getAuthentication();
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,9 +28,11 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication;
import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications; import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications;
import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction; import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction;
@ -40,6 +42,8 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
@ -85,6 +89,21 @@ public class SecurityReactorContextConfigurationResourceServerTests {
// @formatter:on // @formatter:on
} }
@Test
public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer();
this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class,
SecurityContextChangedListenerConfig.class).autowire();
MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication));
// @formatter:off
this.mockMvc.perform(authenticatedRequest)
.andExpect(status().isOk())
.andExpect(content().string("Bearer token"));
// @formatter:on
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
verify(strategy, atLeastOnce()).getContext();
}
@EnableWebSecurity @EnableWebSecurity
static class BearerFilterConfig extends WebSecurityConfigurerAdapter { static class BearerFilterConfig extends WebSecurityConfigurerAdapter {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,12 +38,14 @@ import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber; import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber;
import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
@ -54,6 +56,8 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.Assertions.entry;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link SecurityReactorContextConfiguration}. * Tests for {@link SecurityReactorContextConfiguration}.
@ -232,6 +236,38 @@ public class SecurityReactorContextConfigurationTests {
// @formatter:on // @formatter:on
} }
@Test
public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
this.spring.register(SecurityConfig.class, SecurityContextChangedListenerConfig.class).autowire();
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
strategy.getContext().setAuthentication(this.authentication);
ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
// @formatter:off
ExchangeFilterFunction filter = (req, next) -> Mono.deferContextual(Mono::just)
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((ctx) -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.cast(Map.class)
.map((attributes) -> clientResponseOk);
// @formatter:on
ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
MockExchangeFunction exchange = new MockExchangeFunction();
Map<Object, Object> expectedContextAttributes = new HashMap<>();
expectedContextAttributes.put(HttpServletRequest.class, null);
expectedContextAttributes.put(HttpServletResponse.class, null);
expectedContextAttributes.put(Authentication.class, this.authentication);
Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
.flatMap((response) -> filter.filter(clientRequest, exchange));
// @formatter:off
StepVerifier.create(clientResponseMono)
.expectAccessibleContext()
.contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
.then()
.expectNext(clientResponseOk)
.verifyComplete();
// @formatter:on
verify(strategy, times(2)).getContext();
}
@EnableWebSecurity @EnableWebSecurity
static class SecurityConfig extends WebSecurityConfigurerAdapter { static class SecurityConfig extends WebSecurityConfigurerAdapter {