mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-05-31 09:12:14 +00:00
Align Servlet ExchangeFilterFunction CoreSubscriber
Fixes gh-7422
This commit is contained in:
parent
d17cbe4e59
commit
2a5bd6e719
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -15,23 +15,28 @@
|
||||
*/
|
||||
package org.springframework.security.config.annotation.web.configuration;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.context.annotation.ImportSelector;
|
||||
import org.springframework.core.type.AnnotationMetadata;
|
||||
import org.springframework.util.ClassUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Used by {@link EnableWebSecurity} to conditionally import {@link OAuth2ClientConfiguration}
|
||||
* when the {@code spring-security-oauth2-client} module is present on the classpath and
|
||||
* {@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server}
|
||||
* module is on the classpath.
|
||||
* Used by {@link EnableWebSecurity} to conditionally import:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link OAuth2ClientConfiguration} when the {@code spring-security-oauth2-client} module is present on the classpath</li>
|
||||
* <li>{@link SecurityReactorContextConfiguration} when the {@code spring-webflux} and {@code spring-security-oauth2-client} module is present on the classpath</li>
|
||||
* <li>{@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server} module is present on the classpath</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @author Josh Cummings
|
||||
* @since 5.1
|
||||
* @see OAuth2ClientConfiguration
|
||||
* @see SecurityReactorContextConfiguration
|
||||
* @see OAuth2ResourceServerConfiguration
|
||||
*/
|
||||
final class OAuth2ImportSelector implements ImportSelector {
|
||||
|
||||
@ -39,13 +44,20 @@ final class OAuth2ImportSelector implements ImportSelector {
|
||||
public String[] selectImports(AnnotationMetadata importingClassMetadata) {
|
||||
List<String> imports = new ArrayList<>();
|
||||
|
||||
if (ClassUtils.isPresent(
|
||||
"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader())) {
|
||||
boolean oauth2ClientPresent = ClassUtils.isPresent(
|
||||
"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader());
|
||||
if (oauth2ClientPresent) {
|
||||
imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration");
|
||||
}
|
||||
|
||||
boolean webfluxPresent = ClassUtils.isPresent(
|
||||
"org.springframework.web.reactive.function.client.ExchangeFilterFunction", getClass().getClassLoader());
|
||||
if (webfluxPresent && oauth2ClientPresent) {
|
||||
imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration");
|
||||
}
|
||||
|
||||
if (ClassUtils.isPresent(
|
||||
"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
|
||||
"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
|
||||
imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ResourceServerConfiguration");
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,165 @@
|
||||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.config.annotation.web.configuration;
|
||||
|
||||
import org.reactivestreams.Publisher;
|
||||
import org.reactivestreams.Subscription;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.Hooks;
|
||||
import reactor.core.publisher.Operators;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
|
||||
|
||||
/**
|
||||
* {@link Configuration} that (potentially) adds a "decorating" {@code Publisher}
|
||||
* for the last operator created in every {@code Mono} or {@code Flux}.
|
||||
*
|
||||
* <p>
|
||||
* The {@code Publisher} is solely responsible for adding
|
||||
* the current {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication}
|
||||
* to the Reactor {@code Context} so that it's accessible in every flow, if required.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @since 5.2
|
||||
* @see OAuth2ImportSelector
|
||||
*/
|
||||
@Configuration(proxyBeanMethods = false)
|
||||
class SecurityReactorContextConfiguration {
|
||||
|
||||
@Bean
|
||||
SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
|
||||
return new SecurityReactorContextSubscriberRegistrar();
|
||||
}
|
||||
|
||||
static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
|
||||
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter =
|
||||
Operators.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
|
||||
|
||||
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, pub -> {
|
||||
if (CollectionUtils.isEmpty(getContextAttributes())) {
|
||||
// No need to decorate so return original Publisher
|
||||
return pub;
|
||||
}
|
||||
return lifter.apply(pub);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() throws Exception {
|
||||
Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
|
||||
}
|
||||
|
||||
<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
|
||||
if (delegate.currentContext().hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
|
||||
// Already enriched. No need to create Subscriber so return original
|
||||
return delegate;
|
||||
}
|
||||
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
|
||||
}
|
||||
|
||||
private static Map<Object, Object> getContextAttributes() {
|
||||
HttpServletRequest servletRequest = null;
|
||||
HttpServletResponse servletResponse = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes != null) {
|
||||
servletRequest = requestAttributes.getRequest();
|
||||
servletResponse = requestAttributes.getResponse();
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (authentication == null && servletRequest == null && servletResponse == null) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
Map<Object, Object> contextAttributes = new HashMap<>();
|
||||
if (servletRequest != null) {
|
||||
contextAttributes.put(HttpServletRequest.class, servletRequest);
|
||||
}
|
||||
if (servletResponse != null) {
|
||||
contextAttributes.put(HttpServletResponse.class, servletResponse);
|
||||
}
|
||||
if (authentication != null) {
|
||||
contextAttributes.put(Authentication.class, authentication);
|
||||
}
|
||||
|
||||
return contextAttributes;
|
||||
}
|
||||
}
|
||||
|
||||
static class SecurityReactorContextSubscriber<T> implements CoreSubscriber<T> {
|
||||
static final String SECURITY_CONTEXT_ATTRIBUTES = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
|
||||
private final CoreSubscriber<T> delegate;
|
||||
private final Context context;
|
||||
|
||||
SecurityReactorContextSubscriber(CoreSubscriber<T> delegate, Map<Object, Object> attributes) {
|
||||
this.delegate = delegate;
|
||||
Context currentContext = this.delegate.currentContext();
|
||||
Context context;
|
||||
if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
|
||||
context = currentContext;
|
||||
} else {
|
||||
context = currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes);
|
||||
}
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return this.context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Subscription s) {
|
||||
this.delegate.onSubscribe(s);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(T t) {
|
||||
this.delegate.onNext(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
this.delegate.onError(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
this.delegate.onComplete();
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,195 @@
|
||||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.config.annotation.web.configuration;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.config.test.SpringTestRule;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.BaseSubscriber;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.entry;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
|
||||
|
||||
/**
|
||||
* Tests for {@link SecurityReactorContextConfiguration}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @since 5.2
|
||||
*/
|
||||
public class SecurityReactorContextConfigurationTests {
|
||||
private MockHttpServletRequest servletRequest;
|
||||
private MockHttpServletResponse servletResponse;
|
||||
private Authentication authentication;
|
||||
private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar =
|
||||
new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar();
|
||||
|
||||
@Rule
|
||||
public final SpringTestRule spring = new SpringTestRule();
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
this.servletRequest = new MockHttpServletRequest();
|
||||
this.servletResponse = new MockHttpServletResponse();
|
||||
this.authentication = new TestingAuthenticationToken("principal", "password");
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
SecurityContextHolder.clearContext();
|
||||
RequestContextHolder.resetRequestAttributes();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() {
|
||||
Context context = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
|
||||
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return context;
|
||||
}
|
||||
};
|
||||
CoreSubscriber<Object> resultSubscriber = this.subscriberRegistrar.createSubscriberIfNecessary(originalSubscriber);
|
||||
assertThat(resultSubscriber).isSameAs(originalSubscriber);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createSubscriberIfNecessaryWhenWebSecurityContextAvailableThenCreateWithParentContext() {
|
||||
RequestContextHolder.setRequestAttributes(
|
||||
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
|
||||
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
||||
|
||||
String testKey = "test_key";
|
||||
String testValue = "test_value";
|
||||
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return Context.of(testKey, testValue);
|
||||
}
|
||||
};
|
||||
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
|
||||
|
||||
Context resultContext = subscriber.currentContext();
|
||||
|
||||
assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue);
|
||||
Map<Object, Object> securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null);
|
||||
assertThat(securityContextAttributes).hasSize(3);
|
||||
assertThat(securityContextAttributes).contains(
|
||||
entry(HttpServletRequest.class, this.servletRequest),
|
||||
entry(HttpServletResponse.class, this.servletResponse),
|
||||
entry(Authentication.class, this.authentication));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createSubscriberIfNecessaryWhenParentContextContainsSecurityContextAttributesThenUseParentContext() {
|
||||
RequestContextHolder.setRequestAttributes(
|
||||
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
|
||||
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
||||
|
||||
Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return parentContext;
|
||||
}
|
||||
};
|
||||
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
|
||||
|
||||
Context resultContext = subscriber.currentContext();
|
||||
assertThat(resultContext).isSameAs(parentContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
|
||||
// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector
|
||||
this.spring.register(SecurityConfig.class).autowire();
|
||||
|
||||
// Setup for SecurityReactorContextSubscriberRegistrar
|
||||
RequestContextHolder.setRequestAttributes(
|
||||
new ServletRequestAttributes(this.servletRequest, this.servletResponse));
|
||||
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
||||
|
||||
ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
|
||||
|
||||
ExchangeFilterFunction filter = (req, next) ->
|
||||
Mono.subscriberContext()
|
||||
.filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES))
|
||||
.map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES))
|
||||
.cast(Map.class)
|
||||
.map(attributes -> {
|
||||
if (attributes.containsKey(HttpServletRequest.class) &&
|
||||
attributes.containsKey(HttpServletResponse.class) &&
|
||||
attributes.containsKey(Authentication.class)) {
|
||||
return clientResponseOk;
|
||||
} else {
|
||||
return ClientResponse.create(HttpStatus.NOT_FOUND).build();
|
||||
}
|
||||
});
|
||||
|
||||
ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build();
|
||||
MockExchangeFunction exchange = new MockExchangeFunction();
|
||||
|
||||
Map<Object, Object> expectedContextAttributes = new HashMap<>();
|
||||
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
|
||||
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
|
||||
expectedContextAttributes.put(Authentication.class, this.authentication);
|
||||
|
||||
Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
|
||||
.flatMap(response -> filter.filter(clientRequest, exchange));
|
||||
|
||||
StepVerifier.create(clientResponseMono)
|
||||
.expectAccessibleContext()
|
||||
.contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
|
||||
.then()
|
||||
.expectNext(clientResponseOk)
|
||||
.verifyComplete();
|
||||
}
|
||||
|
||||
@EnableWebSecurity
|
||||
static class SecurityConfig extends WebSecurityConfigurerAdapter {
|
||||
|
||||
@Override
|
||||
protected void configure(HttpSecurity http) throws Exception {
|
||||
}
|
||||
}
|
||||
}
|
@ -16,10 +16,6 @@
|
||||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import org.reactivestreams.Subscription;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
@ -47,10 +43,7 @@ import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.Hooks;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.publisher.Operators;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
@ -100,8 +93,10 @@ import java.util.function.Consumer;
|
||||
* @since 5.1
|
||||
* @see OAuth2AuthorizedClientManager
|
||||
*/
|
||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
|
||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
|
||||
|
||||
// Same key as in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES
|
||||
static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
|
||||
|
||||
/**
|
||||
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
|
||||
@ -112,8 +107,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
|
||||
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
|
||||
|
||||
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
|
||||
|
||||
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken(
|
||||
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
|
||||
|
||||
@ -175,16 +168,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
return authorizedClientManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
|
||||
*
|
||||
@ -382,22 +365,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
}
|
||||
|
||||
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
|
||||
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
|
||||
if (holder != null) {
|
||||
HttpServletRequest request = holder.getRequest();
|
||||
if (request != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
|
||||
}
|
||||
|
||||
HttpServletResponse response = holder.getResponse();
|
||||
if (response != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
|
||||
}
|
||||
|
||||
Authentication authentication = holder.getAuthentication();
|
||||
if (authentication != null) {
|
||||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
|
||||
}
|
||||
// NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key
|
||||
if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) {
|
||||
return;
|
||||
}
|
||||
Map<Object, Object> contextAttributes = ctx.get(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY);
|
||||
HttpServletRequest servletRequest = (HttpServletRequest) contextAttributes.get(HttpServletRequest.class);
|
||||
if (servletRequest != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest);
|
||||
}
|
||||
HttpServletResponse servletResponse = (HttpServletResponse) contextAttributes.get(HttpServletResponse.class);
|
||||
if (servletResponse != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse);
|
||||
}
|
||||
Authentication authentication = (Authentication) contextAttributes.get(Authentication.class);
|
||||
if (authentication != null) {
|
||||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
|
||||
}
|
||||
}
|
||||
|
||||
@ -503,23 +486,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
.build();
|
||||
}
|
||||
|
||||
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
|
||||
HttpServletRequest request = null;
|
||||
HttpServletResponse response = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes != null) {
|
||||
request = requestAttributes.getRequest();
|
||||
response = requestAttributes.getResponse();
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (authentication == null && request == null && response == null) {
|
||||
//do not need to create RequestContextSubscriber with empty data
|
||||
return delegate;
|
||||
}
|
||||
return new RequestContextSubscriber<>(delegate, request, response, authentication);
|
||||
}
|
||||
|
||||
static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
|
||||
return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
|
||||
}
|
||||
@ -587,87 +553,4 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
return new UnsupportedOperationException("Not Supported");
|
||||
}
|
||||
}
|
||||
|
||||
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
||||
static final String REQUEST_CONTEXT_DATA_HOLDER =
|
||||
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
|
||||
private final CoreSubscriber<T> delegate;
|
||||
private final Context context;
|
||||
|
||||
RequestContextSubscriber(CoreSubscriber<T> delegate,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
Authentication authentication) {
|
||||
this.delegate = delegate;
|
||||
|
||||
Context parentContext = this.delegate.currentContext();
|
||||
Context context;
|
||||
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
|
||||
context = parentContext;
|
||||
} else {
|
||||
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
|
||||
}
|
||||
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static RequestContextDataHolder getRequestContext(Context ctx) {
|
||||
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return this.context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Subscription s) {
|
||||
this.delegate.onSubscribe(s);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(T t) {
|
||||
this.delegate.onNext(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
this.delegate.onError(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
this.delegate.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
static class RequestContextDataHolder {
|
||||
private final HttpServletRequest request;
|
||||
private final HttpServletResponse response;
|
||||
private final Authentication authentication;
|
||||
|
||||
RequestContextDataHolder(@Nullable HttpServletRequest request,
|
||||
@Nullable HttpServletResponse response,
|
||||
@Nullable Authentication authentication) {
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
this.authentication = authentication;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private HttpServletRequest getRequest() {
|
||||
return this.request;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private HttpServletResponse getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private Authentication getAuthentication() {
|
||||
return this.authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,16 +43,20 @@ import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.blockhound.BlockHound;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
||||
|
||||
/**
|
||||
@ -104,7 +108,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
|
||||
});
|
||||
this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
this.authorizedClientFilter.afterPropertiesSet();
|
||||
this.server = new MockWebServer();
|
||||
this.server.start();
|
||||
this.serverUrl = this.server.url("/").toString();
|
||||
@ -120,7 +123,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
|
||||
|
||||
@After
|
||||
public void cleanup() throws Exception {
|
||||
this.authorizedClientFilter.destroy();
|
||||
this.server.shutdown();
|
||||
SecurityContextHolder.clearContext();
|
||||
RequestContextHolder.resetRequestAttributes();
|
||||
@ -248,6 +250,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
|
||||
.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
|
||||
.retrieve()
|
||||
.bodyToMono(String.class))
|
||||
.subscriberContext(context())
|
||||
.block();
|
||||
|
||||
assertThat(this.server.getRequestCount()).isEqualTo(4);
|
||||
@ -259,6 +262,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
|
||||
assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
|
||||
}
|
||||
|
||||
private Context context() {
|
||||
Map<Object, Object> contextAttributes = new HashMap<>();
|
||||
contextAttributes.put(HttpServletRequest.class, this.request);
|
||||
contextAttributes.put(HttpServletResponse.class, this.response);
|
||||
contextAttributes.put(Authentication.class, this.authentication);
|
||||
return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
|
||||
}
|
||||
|
||||
private MockResponse jsonResponse(String json) {
|
||||
return new MockResponse()
|
||||
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
|
||||
|
@ -76,12 +76,10 @@ import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.BodyInserter;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.BaseSubscriber;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
@ -93,7 +91,6 @@ import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
@ -163,7 +160,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
public void cleanup() throws Exception {
|
||||
SecurityContextHolder.clearContext();
|
||||
RequestContextHolder.resetRequestAttributes();
|
||||
this.function.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -591,18 +587,15 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
// gh-6483
|
||||
@Test
|
||||
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||
this.registration, "principalName", this.accessToken);
|
||||
@ -619,12 +612,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
// Default request attributes NOT set
|
||||
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
|
||||
|
||||
Context context = context(servletRequest, servletResponse, authentication);
|
||||
|
||||
this.function.filter(request1, this.exchange)
|
||||
.flatMap(response -> this.function.filter(request2, this.exchange))
|
||||
.subscriberContext(context)
|
||||
.block();
|
||||
|
||||
this.function.destroy(); // Hooks.onLastOperator() released
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
|
||||
@ -641,147 +635,12 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() {
|
||||
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
request = requests.get(0);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
this.function.destroy(); // Hooks.onLastOperator() released
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
request = requests.get(0);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
assertThatCode(() -> Mono.subscriberContext().block())
|
||||
.as("RequestContext Hook brakes application outside of web/security context")
|
||||
.doesNotThrowAnyException();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
|
||||
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
|
||||
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
|
||||
assertThat(resultSubscriber).isSameAs(originalSubscriber);
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
|
||||
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
|
||||
testRequestContextSubscriber(null, null, this.authentication);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
|
||||
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
|
||||
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return parentContext;
|
||||
}
|
||||
};
|
||||
|
||||
RequestContextSubscriber<Object> requestContextSubscriber =
|
||||
new RequestContextSubscriber<>(parent, null, null, authentication);
|
||||
|
||||
Context resultContext = requestContextSubscriber.currentContext();
|
||||
|
||||
assertThat(resultContext)
|
||||
.describedAs("parent context was replaced")
|
||||
.isSameAs(parentContext);
|
||||
}
|
||||
|
||||
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
|
||||
MockHttpServletResponse servletResponse,
|
||||
Authentication authentication) {
|
||||
String testKey = "test_key";
|
||||
String testValue = "test_value";
|
||||
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return Context.of(testKey, testValue);
|
||||
}
|
||||
};
|
||||
|
||||
RequestContextSubscriber<Object> requestContextSubscriber =
|
||||
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
|
||||
|
||||
Context resultContext = requestContextSubscriber.currentContext();
|
||||
|
||||
assertThat(resultContext)
|
||||
.describedAs("result context is null")
|
||||
.isNotNull();
|
||||
|
||||
assertThat(resultContext.getOrEmpty(testKey))
|
||||
.describedAs("context is replaced")
|
||||
.hasValue(testValue);
|
||||
|
||||
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
|
||||
assertThat(dataHolder)
|
||||
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
|
||||
.isNotNull()
|
||||
.hasFieldOrPropertyWithValue("request", servletRequest)
|
||||
.hasFieldOrPropertyWithValue("response", servletResponse)
|
||||
.hasFieldOrPropertyWithValue("authentication", authentication);
|
||||
private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) {
|
||||
Map<Object, Object> contextAttributes = new HashMap<>();
|
||||
contextAttributes.put(HttpServletRequest.class, servletRequest);
|
||||
contextAttributes.put(HttpServletResponse.class, servletResponse);
|
||||
contextAttributes.put(Authentication.class, authentication);
|
||||
return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
|
||||
}
|
||||
|
||||
private static String getBody(ClientRequest request) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user