Fix NPE in RequestContextSubscriber
RequestContextSubscriber could cause NPE if Mono/Flux.subscribe() was invoked outside of Web Context. In addition it replaced source Context with its own without respect to old data. Now Request Context Data is Propagated within holder class and it is added to existing reactor Context if Holder is not empty. Fixes gh-7228
This commit is contained in:
parent
1de885e298
commit
ffc43e02c3
|
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
|
|||
import org.reactivestreams.Subscription;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
|
@ -95,6 +96,7 @@ import java.util.function.Consumer;
|
|||
*
|
||||
* @author Rob Winch
|
||||
* @author Joe Grandja
|
||||
* @author Roman Matiushchenko
|
||||
* @since 5.1
|
||||
* @see OAuth2AuthorizedClientManager
|
||||
*/
|
||||
|
@ -174,7 +176,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
|
||||
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -378,14 +380,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
}
|
||||
|
||||
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
|
||||
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
|
||||
}
|
||||
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
|
||||
}
|
||||
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
|
||||
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
|
||||
if (holder != null) {
|
||||
HttpServletRequest request = holder.getRequest();
|
||||
if (request != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
|
||||
}
|
||||
|
||||
HttpServletResponse response = holder.getResponse();
|
||||
if (response != null) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
|
||||
}
|
||||
|
||||
Authentication authentication = holder.getAuthentication();
|
||||
if (authentication != null) {
|
||||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -472,7 +482,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
.build();
|
||||
}
|
||||
|
||||
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
|
||||
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
|
||||
HttpServletRequest request = null;
|
||||
HttpServletResponse response = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
|
@ -482,6 +492,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
response = requestAttributes.getResponse();
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (authentication == null && request == null && response == null) {
|
||||
//do not need to create RequestContextSubscriber with empty data
|
||||
return delegate;
|
||||
}
|
||||
return new RequestContextSubscriber<>(delegate, request, response, authentication);
|
||||
}
|
||||
|
||||
|
@ -553,34 +567,37 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
}
|
||||
}
|
||||
|
||||
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
||||
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
|
||||
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
||||
static final String REQUEST_CONTEXT_DATA_HOLDER =
|
||||
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
|
||||
private final CoreSubscriber<T> delegate;
|
||||
private final HttpServletRequest request;
|
||||
private final HttpServletResponse response;
|
||||
private final Authentication authentication;
|
||||
private final Context context;
|
||||
|
||||
private RequestContextSubscriber(CoreSubscriber<T> delegate,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
Authentication authentication) {
|
||||
RequestContextSubscriber(CoreSubscriber<T> delegate,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
Authentication authentication) {
|
||||
this.delegate = delegate;
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
this.authentication = authentication;
|
||||
|
||||
Context parentContext = this.delegate.currentContext();
|
||||
Context context;
|
||||
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
|
||||
context = parentContext;
|
||||
} else {
|
||||
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
|
||||
}
|
||||
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static RequestContextDataHolder getRequestContext(Context ctx) {
|
||||
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
Context context = this.delegate.currentContext();
|
||||
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
|
||||
return context;
|
||||
}
|
||||
return Context.of(
|
||||
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
|
||||
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
|
||||
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
|
||||
AUTHENTICATION_ATTR_NAME, this.authentication);
|
||||
return this.context;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -603,4 +620,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|||
this.delegate.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
static class RequestContextDataHolder {
|
||||
private final HttpServletRequest request;
|
||||
private final HttpServletResponse response;
|
||||
private final Authentication authentication;
|
||||
|
||||
RequestContextDataHolder(@Nullable HttpServletRequest request,
|
||||
@Nullable HttpServletResponse response,
|
||||
@Nullable Authentication authentication) {
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
this.authentication = authentication;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private HttpServletRequest getRequest() {
|
||||
return this.request;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private HttpServletResponse getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private Authentication getAuthentication() {
|
||||
return this.authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,6 +72,10 @@ import org.springframework.web.context.request.ServletRequestAttributes;
|
|||
import org.springframework.web.reactive.function.BodyInserter;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.BaseSubscriber;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
|
@ -84,6 +88,7 @@ import java.util.Optional;
|
|||
import java.util.function.Consumer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
|
@ -144,9 +149,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
}
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
public void cleanup() throws Exception {
|
||||
SecurityContextHolder.clearContext();
|
||||
RequestContextHolder.resetRequestAttributes();
|
||||
this.function.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -633,6 +639,90 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
assertThatCode(() -> Mono.subscriberContext().block())
|
||||
.as("RequestContext Hook brakes application outside of web/security context")
|
||||
.doesNotThrowAnyException();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
|
||||
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
|
||||
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
|
||||
assertThat(resultSubscriber).isSameAs(originalSubscriber);
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
|
||||
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
|
||||
}
|
||||
|
||||
// gh-7228
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
|
||||
testRequestContextSubscriber(null, null, this.authentication);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
|
||||
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
|
||||
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return parentContext;
|
||||
}
|
||||
};
|
||||
|
||||
RequestContextSubscriber<Object> requestContextSubscriber =
|
||||
new RequestContextSubscriber<>(parent, null, null, authentication);
|
||||
|
||||
Context resultContext = requestContextSubscriber.currentContext();
|
||||
|
||||
assertThat(resultContext)
|
||||
.describedAs("parent context was replaced")
|
||||
.isSameAs(parentContext);
|
||||
}
|
||||
|
||||
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
|
||||
MockHttpServletResponse servletResponse,
|
||||
Authentication authentication) {
|
||||
String testKey = "test_key";
|
||||
String testValue = "test_value";
|
||||
|
||||
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
return Context.of(testKey, testValue);
|
||||
}
|
||||
};
|
||||
|
||||
RequestContextSubscriber<Object> requestContextSubscriber =
|
||||
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
|
||||
|
||||
Context resultContext = requestContextSubscriber.currentContext();
|
||||
|
||||
assertThat(resultContext)
|
||||
.describedAs("result context is null")
|
||||
.isNotNull();
|
||||
|
||||
assertThat(resultContext.getOrEmpty(testKey))
|
||||
.describedAs("context is replaced")
|
||||
.hasValue(testValue);
|
||||
|
||||
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
|
||||
assertThat(dataHolder)
|
||||
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
|
||||
.isNotNull()
|
||||
.hasFieldOrPropertyWithValue("request", servletRequest)
|
||||
.hasFieldOrPropertyWithValue("response", servletResponse)
|
||||
.hasFieldOrPropertyWithValue("authentication", authentication);
|
||||
}
|
||||
|
||||
private static String getBody(ClientRequest request) {
|
||||
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
||||
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
||||
|
|
Loading…
Reference in New Issue