Fix NPE in RequestContextSubscriber

RequestContextSubscriber could cause NPE if Mono/Flux.subscribe()
was invoked outside of Web Context.
In addition it replaced source Context with its own without respect
to old data.
Now Request Context Data is Propagated within holder class and
it is added to existing reactor Context if Holder is not empty.

Fixes gh-7228
This commit is contained in:
Roman Matiushchenko 2019-08-30 16:49:38 +03:00
parent 1de885e298
commit ffc43e02c3
2 changed files with 168 additions and 32 deletions

View File

@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
@ -95,6 +96,7 @@ import java.util.function.Consumer;
*
* @author Rob Winch
* @author Joe Grandja
* @author Roman Matiushchenko
* @since 5.1
* @see OAuth2AuthorizedClientManager
*/
@ -174,7 +176,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
@Override
public void afterPropertiesSet() throws Exception {
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
}
@Override
@ -378,14 +380,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
}
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
}
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
}
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
if (holder != null) {
HttpServletRequest request = holder.getRequest();
if (request != null) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
}
HttpServletResponse response = holder.getResponse();
if (response != null) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
}
Authentication authentication = holder.getAuthentication();
if (authentication != null) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}
}
}
@ -472,7 +482,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
.build();
}
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
HttpServletRequest request = null;
HttpServletResponse response = null;
ServletRequestAttributes requestAttributes =
@ -482,6 +492,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
response = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null && request == null && response == null) {
//do not need to create RequestContextSubscriber with empty data
return delegate;
}
return new RequestContextSubscriber<>(delegate, request, response, authentication);
}
@ -553,34 +567,37 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
}
}
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
static final String REQUEST_CONTEXT_DATA_HOLDER =
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
private final CoreSubscriber<T> delegate;
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Authentication authentication;
private final Context context;
private RequestContextSubscriber(CoreSubscriber<T> delegate,
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
RequestContextSubscriber(CoreSubscriber<T> delegate,
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
this.delegate = delegate;
this.request = request;
this.response = response;
this.authentication = authentication;
Context parentContext = this.delegate.currentContext();
Context context;
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
context = parentContext;
} else {
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
}
this.context = context;
}
@Nullable
private static RequestContextDataHolder getRequestContext(Context ctx) {
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
}
@Override
public Context currentContext() {
Context context = this.delegate.currentContext();
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
return context;
}
return Context.of(
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
AUTHENTICATION_ATTR_NAME, this.authentication);
return this.context;
}
@Override
@ -603,4 +620,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
this.delegate.onComplete();
}
}
static class RequestContextDataHolder {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Authentication authentication;
RequestContextDataHolder(@Nullable HttpServletRequest request,
@Nullable HttpServletResponse response,
@Nullable Authentication authentication) {
this.request = request;
this.response = response;
this.authentication = authentication;
}
@Nullable
private HttpServletRequest getRequest() {
return this.request;
}
@Nullable
private HttpServletResponse getResponse() {
return this.response;
}
@Nullable
private Authentication getAuthentication() {
return this.authentication;
}
}
}

View File

@ -72,6 +72,10 @@ import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.net.URI;
import java.time.Duration;
@ -84,6 +88,7 @@ import java.util.Optional;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpMethod.GET;
@ -144,9 +149,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
}
@After
public void cleanup() {
public void cleanup() throws Exception {
SecurityContextHolder.clearContext();
RequestContextHolder.resetRequestAttributes();
this.function.destroy();
}
@Test
@ -633,6 +639,90 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request)).isEmpty();
}
// gh-7228
@Test
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
assertThatCode(() -> Mono.subscriberContext().block())
.as("RequestContext Hook brakes application outside of web/security context")
.doesNotThrowAnyException();
}
@Test
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
assertThat(resultSubscriber).isSameAs(originalSubscriber);
}
// gh-7228
@Test
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
}
// gh-7228
@Test
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
testRequestContextSubscriber(null, null, this.authentication);
}
@Test
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return parentContext;
}
};
RequestContextSubscriber<Object> requestContextSubscriber =
new RequestContextSubscriber<>(parent, null, null, authentication);
Context resultContext = requestContextSubscriber.currentContext();
assertThat(resultContext)
.describedAs("parent context was replaced")
.isSameAs(parentContext);
}
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
MockHttpServletResponse servletResponse,
Authentication authentication) {
String testKey = "test_key";
String testValue = "test_value";
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
@Override
public Context currentContext() {
return Context.of(testKey, testValue);
}
};
RequestContextSubscriber<Object> requestContextSubscriber =
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
Context resultContext = requestContextSubscriber.currentContext();
assertThat(resultContext)
.describedAs("result context is null")
.isNotNull();
assertThat(resultContext.getOrEmpty(testKey))
.describedAs("context is replaced")
.hasValue(testValue);
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
assertThat(dataHolder)
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
.isNotNull()
.hasFieldOrPropertyWithValue("request", servletRequest)
.hasFieldOrPropertyWithValue("response", servletResponse)
.hasFieldOrPropertyWithValue("authentication", authentication);
}
private static String getBody(ClientRequest request) {
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));