Align Servlet ExchangeFilterFunction CoreSubscriber

Fixes gh-7422
This commit is contained in:
Joe Grandja 2019-09-24 20:35:03 -04:00
parent d17cbe4e59
commit 2a5bd6e719
6 changed files with 426 additions and 301 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,23 +15,28 @@
*/
package org.springframework.security.config.annotation.web.configuration;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.util.ClassUtils;
import java.util.ArrayList;
import java.util.List;
/**
* Used by {@link EnableWebSecurity} to conditionally import {@link OAuth2ClientConfiguration}
* when the {@code spring-security-oauth2-client} module is present on the classpath and
* {@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server}
* module is on the classpath.
* Used by {@link EnableWebSecurity} to conditionally import:
*
* <ul>
* <li>{@link OAuth2ClientConfiguration} when the {@code spring-security-oauth2-client} module is present on the classpath</li>
* <li>{@link SecurityReactorContextConfiguration} when the {@code spring-webflux} and {@code spring-security-oauth2-client} module is present on the classpath</li>
* <li>{@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server} module is present on the classpath</li>
* </ul>
*
* @author Joe Grandja
* @author Josh Cummings
* @since 5.1
* @see OAuth2ClientConfiguration
* @see SecurityReactorContextConfiguration
* @see OAuth2ResourceServerConfiguration
*/
final class OAuth2ImportSelector implements ImportSelector {
@ -39,13 +44,20 @@ final class OAuth2ImportSelector implements ImportSelector {
public String[] selectImports(AnnotationMetadata importingClassMetadata) {
List<String> imports = new ArrayList<>();
if (ClassUtils.isPresent(
"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader())) {
boolean oauth2ClientPresent = ClassUtils.isPresent(
"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader());
if (oauth2ClientPresent) {
imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration");
}
boolean webfluxPresent = ClassUtils.isPresent(
"org.springframework.web.reactive.function.client.ExchangeFilterFunction", getClass().getClassLoader());
if (webfluxPresent && oauth2ClientPresent) {
imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration");
}
if (ClassUtils.isPresent(
"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ResourceServerConfiguration");
}

View File

@ -0,0 +1,165 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.configuration;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Operators;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
/**
* {@link Configuration} that (potentially) adds a "decorating" {@code Publisher}
* for the last operator created in every {@code Mono} or {@code Flux}.
*
* <p>
* The {@code Publisher} is solely responsible for adding
* the current {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication}
* to the Reactor {@code Context} so that it's accessible in every flow, if required.
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2ImportSelector
*/
@Configuration(proxyBeanMethods = false)
class SecurityReactorContextConfiguration {
@Bean
SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
return new SecurityReactorContextSubscriberRegistrar();
}
static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
@Override
public void afterPropertiesSet() throws Exception {
Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter =
Operators.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, pub -> {
if (CollectionUtils.isEmpty(getContextAttributes())) {
// No need to decorate so return original Publisher
return pub;
}
return lifter.apply(pub);
});
}
@Override
public void destroy() throws Exception {
Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
}
<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
if (delegate.currentContext().hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
// Already enriched. No need to create Subscriber so return original
return delegate;
}
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
}
private static Map<Object, Object> getContextAttributes() {
HttpServletRequest servletRequest = null;
HttpServletResponse servletResponse = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
servletRequest = requestAttributes.getRequest();
servletResponse = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null && servletRequest == null && servletResponse == null) {
return Collections.emptyMap();
}
Map<Object, Object> contextAttributes = new HashMap<>();
if (servletRequest != null) {
contextAttributes.put(HttpServletRequest.class, servletRequest);
}
if (servletResponse != null) {
contextAttributes.put(HttpServletResponse.class, servletResponse);
}
if (authentication != null) {
contextAttributes.put(Authentication.class, authentication);
}
return contextAttributes;
}
}
static class SecurityReactorContextSubscriber<T> implements CoreSubscriber<T> {
static final String SECURITY_CONTEXT_ATTRIBUTES = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
private final CoreSubscriber<T> delegate;
private final Context context;
SecurityReactorContextSubscriber(CoreSubscriber<T> delegate, Map<Object, Object> attributes) {
this.delegate = delegate;
Context currentContext = this.delegate.currentContext();
Context context;
if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
context = currentContext;
} else {
context = currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes);
}
this.context = context;
}
@Override
public Context currentContext() {
return this.context;
}
@Override
public void onSubscribe(Subscription s) {
this.delegate.onSubscribe(s);
}
@Override
public void onNext(T t) {
this.delegate.onNext(t);
}
@Override
public void onError(Throwable t) {
this.delegate.onError(t);
}
@Override
public void onComplete() {
this.delegate.onComplete();
}
}
}

View File

@ -0,0 +1,195 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.configuration;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
/**
* Tests for {@link SecurityReactorContextConfiguration}.
*
* @author Joe Grandja
* @since 5.2
*/
public class SecurityReactorContextConfigurationTests {
private MockHttpServletRequest servletRequest;
private MockHttpServletResponse servletResponse;
private Authentication authentication;
private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar =
new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar();
@Rule
public final SpringTestRule spring = new SpringTestRule();
@Before
public void setup() {
this.servletRequest = new MockHttpServletRequest();
this.servletResponse = new MockHttpServletResponse();
this.authentication = new TestingAuthenticationToken("principal", "password");
}
@After
public void cleanup() {
SecurityContextHolder.clearContext();
RequestContextHolder.resetRequestAttributes();
}
@Test
public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() {
Context context = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return context;
}
};
CoreSubscriber<Object> resultSubscriber = this.subscriberRegistrar.createSubscriberIfNecessary(originalSubscriber);
assertThat(resultSubscriber).isSameAs(originalSubscriber);
}
@Test
public void createSubscriberIfNecessaryWhenWebSecurityContextAvailableThenCreateWithParentContext() {
RequestContextHolder.setRequestAttributes(
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
SecurityContextHolder.getContext().setAuthentication(this.authentication);
String testKey = "test_key";
String testValue = "test_value";
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return Context.of(testKey, testValue);
}
};
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
Context resultContext = subscriber.currentContext();
assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue);
Map<Object, Object> securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null);
assertThat(securityContextAttributes).hasSize(3);
assertThat(securityContextAttributes).contains(
entry(HttpServletRequest.class, this.servletRequest),
entry(HttpServletResponse.class, this.servletResponse),
entry(Authentication.class, this.authentication));
}
@Test
public void createSubscriberIfNecessaryWhenParentContextContainsSecurityContextAttributesThenUseParentContext() {
RequestContextHolder.setRequestAttributes(
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return parentContext;
}
};
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
Context resultContext = subscriber.currentContext();
assertThat(resultContext).isSameAs(parentContext);
}
@Test
public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector
this.spring.register(SecurityConfig.class).autowire();
// Setup for SecurityReactorContextSubscriberRegistrar
RequestContextHolder.setRequestAttributes(
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
SecurityContextHolder.getContext().setAuthentication(this.authentication);
ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
ExchangeFilterFunction filter = (req, next) ->
Mono.subscriberContext()
.filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES))
.map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES))
.cast(Map.class)
.map(attributes -> {
if (attributes.containsKey(HttpServletRequest.class) &&
attributes.containsKey(HttpServletResponse.class) &&
attributes.containsKey(Authentication.class)) {
return clientResponseOk;
} else {
return ClientResponse.create(HttpStatus.NOT_FOUND).build();
}
});
ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build();
MockExchangeFunction exchange = new MockExchangeFunction();
Map<Object, Object> expectedContextAttributes = new HashMap<>();
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
expectedContextAttributes.put(Authentication.class, this.authentication);
Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
.flatMap(response -> filter.filter(clientRequest, exchange));
StepVerifier.create(clientResponseMono)
.expectAccessibleContext()
.contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
.then()
.expectNext(clientResponseOk)
.verifyComplete();
}
@EnableWebSecurity
static class SecurityConfig extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
}
}
}

View File

@ -16,10 +16,6 @@
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
@ -47,10 +43,7 @@ import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
@ -100,8 +93,10 @@ import java.util.function.Consumer;
* @since 5.1
* @see OAuth2AuthorizedClientManager
*/
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
// Same key as in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES
static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
/**
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
@ -112,8 +107,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken(
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
@ -175,16 +168,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
return authorizedClientManager;
}
@Override
public void afterPropertiesSet() {
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
}
@Override
public void destroy() {
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
}
/**
* Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
*
@ -382,22 +365,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
}
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
if (holder != null) {
HttpServletRequest request = holder.getRequest();
if (request != null) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
}
HttpServletResponse response = holder.getResponse();
if (response != null) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
}
Authentication authentication = holder.getAuthentication();
if (authentication != null) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}
// NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key
if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) {
return;
}
Map<Object, Object> contextAttributes = ctx.get(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY);
HttpServletRequest servletRequest = (HttpServletRequest) contextAttributes.get(HttpServletRequest.class);
if (servletRequest != null) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest);
}
HttpServletResponse servletResponse = (HttpServletResponse) contextAttributes.get(HttpServletResponse.class);
if (servletResponse != null) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse);
}
Authentication authentication = (Authentication) contextAttributes.get(Authentication.class);
if (authentication != null) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}
}
@ -503,23 +486,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
.build();
}
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
HttpServletRequest request = null;
HttpServletResponse response = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
request = requestAttributes.getRequest();
response = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null && request == null && response == null) {
//do not need to create RequestContextSubscriber with empty data
return delegate;
}
return new RequestContextSubscriber<>(delegate, request, response, authentication);
}
static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
}
@ -587,87 +553,4 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
return new UnsupportedOperationException("Not Supported");
}
}
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
static final String REQUEST_CONTEXT_DATA_HOLDER =
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
private final CoreSubscriber<T> delegate;
private final Context context;
RequestContextSubscriber(CoreSubscriber<T> delegate,
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
this.delegate = delegate;
Context parentContext = this.delegate.currentContext();
Context context;
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
context = parentContext;
} else {
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
}
this.context = context;
}
@Nullable
private static RequestContextDataHolder getRequestContext(Context ctx) {
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
}
@Override
public Context currentContext() {
return this.context;
}
@Override
public void onSubscribe(Subscription s) {
this.delegate.onSubscribe(s);
}
@Override
public void onNext(T t) {
this.delegate.onNext(t);
}
@Override
public void onError(Throwable t) {
this.delegate.onError(t);
}
@Override
public void onComplete() {
this.delegate.onComplete();
}
}
static class RequestContextDataHolder {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Authentication authentication;
RequestContextDataHolder(@Nullable HttpServletRequest request,
@Nullable HttpServletResponse response,
@Nullable Authentication authentication) {
this.request = request;
this.response = response;
this.authentication = authentication;
}
@Nullable
private HttpServletRequest getRequest() {
return this.request;
}
@Nullable
private HttpServletResponse getResponse() {
return this.response;
}
@Nullable
private Authentication getAuthentication() {
return this.authentication;
}
}
}

View File

@ -43,16 +43,20 @@ import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.blockhound.BlockHound;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
/**
@ -104,7 +108,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
});
this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientFilter.afterPropertiesSet();
this.server = new MockWebServer();
this.server.start();
this.serverUrl = this.server.url("/").toString();
@ -120,7 +123,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
@After
public void cleanup() throws Exception {
this.authorizedClientFilter.destroy();
this.server.shutdown();
SecurityContextHolder.clearContext();
RequestContextHolder.resetRequestAttributes();
@ -248,6 +250,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
.retrieve()
.bodyToMono(String.class))
.subscriberContext(context())
.block();
assertThat(this.server.getRequestCount()).isEqualTo(4);
@ -259,6 +262,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
}
private Context context() {
Map<Object, Object> contextAttributes = new HashMap<>();
contextAttributes.put(HttpServletRequest.class, this.request);
contextAttributes.put(HttpServletResponse.class, this.response);
contextAttributes.put(Authentication.class, this.authentication);
return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)

View File

@ -76,12 +76,10 @@ import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
@ -93,7 +91,6 @@ import java.util.Optional;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpMethod.GET;
@ -163,7 +160,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
public void cleanup() throws Exception {
SecurityContextHolder.clearContext();
RequestContextHolder.resetRequestAttributes();
this.function.destroy();
}
@Test
@ -591,18 +587,15 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
// gh-6483
@Test
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
@ -619,12 +612,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
// Default request attributes NOT set
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
Context context = context(servletRequest, servletResponse, authentication);
this.function.filter(request1, this.exchange)
.flatMap(response -> this.function.filter(request2, this.exchange))
.subscriberContext(context)
.block();
this.function.destroy(); // Hooks.onLastOperator() released
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
@ -641,147 +635,12 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() {
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
this.function.filter(request, this.exchange).block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
this.function.destroy(); // Hooks.onLastOperator() released
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
this.function.filter(request, this.exchange).block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}
// gh-7228
@Test
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
assertThatCode(() -> Mono.subscriberContext().block())
.as("RequestContext Hook brakes application outside of web/security context")
.doesNotThrowAnyException();
}
@Test
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
assertThat(resultSubscriber).isSameAs(originalSubscriber);
}
// gh-7228
@Test
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
}
// gh-7228
@Test
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
testRequestContextSubscriber(null, null, this.authentication);
}
@Test
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return parentContext;
}
};
RequestContextSubscriber<Object> requestContextSubscriber =
new RequestContextSubscriber<>(parent, null, null, authentication);
Context resultContext = requestContextSubscriber.currentContext();
assertThat(resultContext)
.describedAs("parent context was replaced")
.isSameAs(parentContext);
}
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
MockHttpServletResponse servletResponse,
Authentication authentication) {
String testKey = "test_key";
String testValue = "test_value";
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return Context.of(testKey, testValue);
}
};
RequestContextSubscriber<Object> requestContextSubscriber =
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
Context resultContext = requestContextSubscriber.currentContext();
assertThat(resultContext)
.describedAs("result context is null")
.isNotNull();
assertThat(resultContext.getOrEmpty(testKey))
.describedAs("context is replaced")
.hasValue(testValue);
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
assertThat(dataHolder)
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
.isNotNull()
.hasFieldOrPropertyWithValue("request", servletRequest)
.hasFieldOrPropertyWithValue("response", servletResponse)
.hasFieldOrPropertyWithValue("authentication", authentication);
private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) {
Map<Object, Object> contextAttributes = new HashMap<>();
contextAttributes.put(HttpServletRequest.class, servletRequest);
contextAttributes.put(HttpServletResponse.class, servletResponse);
contextAttributes.put(Authentication.class, authentication);
return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
}
private static String getBody(ClientRequest request) {