Pick up SecurityContextHolderStrategy for WebClient integration

Issue gh-11061
This commit is contained in:
Josh Cummings 2022-06-21 17:13:31 -06:00
parent 27de315e5e
commit e8723f1f43
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
3 changed files with 90 additions and 14 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,10 +37,13 @@ import reactor.util.context.Context;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
@ -62,24 +65,37 @@ import org.springframework.web.context.request.ServletRequestAttributes;
@Configuration(proxyBeanMethods = false)
class SecurityReactorContextConfiguration {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Bean
SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
return new SecurityReactorContextSubscriberRegistrar();
SecurityReactorContextSubscriberRegistrar registrar = new SecurityReactorContextSubscriberRegistrar();
registrar.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
return registrar;
}
@Autowired(required = false)
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
private final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
static {
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
SecurityReactorContextSubscriberRegistrar() {
this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
SecurityReactorContextSubscriberRegistrar::getRequest);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
SecurityReactorContextSubscriberRegistrar::getResponse);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
SecurityReactorContextSubscriberRegistrar::getAuthentication);
this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, this::getAuthentication);
}
@Override
@ -94,6 +110,11 @@ class SecurityReactorContextConfiguration {
Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
}
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) {
// Already enriched. No need to create Subscriber so return original
@ -102,8 +123,8 @@ class SecurityReactorContextConfiguration {
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
}
private static Map<Object, Object> getContextAttributes() {
return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
private Map<Object, Object> getContextAttributes() {
return new LoadingMap<>(this.CONTEXT_ATTRIBUTE_VALUE_LOADERS);
}
private static HttpServletRequest getRequest() {
@ -124,8 +145,8 @@ class SecurityReactorContextConfiguration {
return null;
}
private static Authentication getAuthentication() {
return SecurityContextHolder.getContext().getAuthentication();
private Authentication getAuthentication() {
return this.securityContextHolderStrategy.getContext().getAuthentication();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,9 +29,11 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication;
import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications;
import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction;
@ -41,6 +43,8 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.WebClient;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
@ -86,6 +90,21 @@ public class SecurityReactorContextConfigurationResourceServerTests {
// @formatter:on
}
@Test
public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer();
this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class,
SecurityContextChangedListenerConfig.class).autowire();
MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication));
// @formatter:off
this.mockMvc.perform(authenticatedRequest)
.andExpect(status().isOk())
.andExpect(content().string("Bearer token"));
// @formatter:on
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
verify(strategy, atLeastOnce()).getContext();
}
@EnableWebSecurity
static class BearerFilterConfig extends WebSecurityConfigurerAdapter {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,12 +39,14 @@ import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
@ -55,6 +57,8 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link SecurityReactorContextConfiguration}.
@ -233,6 +237,38 @@ public class SecurityReactorContextConfigurationTests {
// @formatter:on
}
@Test
public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
this.spring.register(SecurityConfig.class, SecurityContextChangedListenerConfig.class).autowire();
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
strategy.getContext().setAuthentication(this.authentication);
ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
// @formatter:off
ExchangeFilterFunction filter = (req, next) -> Mono.deferContextual(Mono::just)
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((ctx) -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.cast(Map.class)
.map((attributes) -> clientResponseOk);
// @formatter:on
ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
MockExchangeFunction exchange = new MockExchangeFunction();
Map<Object, Object> expectedContextAttributes = new HashMap<>();
expectedContextAttributes.put(HttpServletRequest.class, null);
expectedContextAttributes.put(HttpServletResponse.class, null);
expectedContextAttributes.put(Authentication.class, this.authentication);
Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
.flatMap((response) -> filter.filter(clientRequest, exchange));
// @formatter:off
StepVerifier.create(clientResponseMono)
.expectAccessibleContext()
.contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
.then()
.expectNext(clientResponseOk)
.verifyComplete();
// @formatter:on
verify(strategy, times(2)).getContext();
}
@EnableWebSecurity
static class SecurityConfig extends WebSecurityConfigurerAdapter {