diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 13a7afce51..22c12146d9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,6 +16,9 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; +import org.reactivestreams.Subscription; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import reactor.core.scheduler.Schedulers; +import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu * @author Rob Winch * @since 5.1 */ -public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { +public final class ServletOAuth2AuthorizedClientExchangeFilterFunction + implements ExchangeFilterFunction, InitializingBean, DisposableBean { + /** * The request attribute name used to locate the {@link OAuth2AuthorizedClient}. */ @@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); + private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); + private Clock clock = Clock.systemUTC(); private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); @@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private String defaultClientRegistrationId; - public ServletOAuth2AuthorizedClientExchangeFilterFunction() {} + public ServletOAuth2AuthorizedClientExchangeFilterFunction() { + } public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, @@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement this.authorizedClientRepository = authorizedClientRepository; } + @Override + public void afterPropertiesSet() throws Exception { + Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub))); + } + + @Override + public void destroy() throws Exception { + Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); + } + /** * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for * client_credentials grant. @@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) - .map(OAuth2AuthorizedClient.class::cast); - return Mono.justOrEmpty(attribute) - .flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient)) + return Mono.just(request) + .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) + .switchIfEmpty(mergeRequestAttributesFromContext(request)) + .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) + .flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); } + private Mono mergeRequestAttributesFromContext(ClientRequest request) { + return Mono.just(ClientRequest.from(request)) + .flatMap(builder -> Mono.subscriberContext() + .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))) + .map(ClientRequest.Builder::build); + } + + private void populateRequestAttributes(Map attrs, Context ctx) { + if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) { + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME)); + } + if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { + attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME)); + } + if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) { + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME)); + } + populateDefaultOAuth2AuthorizedClient(attrs); + } + private void populateDefaultRequestResponse(Map attrs) { if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( HTTP_SERVLET_RESPONSE_ATTR_NAME)) { @@ -425,6 +468,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement .build(); } + private CoreSubscriber createRequestContextSubscriber(CoreSubscriber delegate) { + HttpServletRequest request = null; + HttpServletResponse response = null; + ServletRequestAttributes requestAttributes = + (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (requestAttributes != null) { + request = requestAttributes.getRequest(); + response = requestAttributes.getResponse(); + } + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + return new RequestContextSubscriber<>(delegate, request, response, authentication); + } + private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) { return BodyInserters .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) @@ -498,4 +554,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement return new UnsupportedOperationException("Not Supported"); } } + + private static class RequestContextSubscriber implements CoreSubscriber { + private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); + private final CoreSubscriber delegate; + private final HttpServletRequest request; + private final HttpServletResponse response; + private final Authentication authentication; + + private RequestContextSubscriber(CoreSubscriber delegate, + HttpServletRequest request, + HttpServletResponse response, + Authentication authentication) { + this.delegate = delegate; + this.request = request; + this.response = response; + this.authentication = authentication; + } + + @Override + public Context currentContext() { + Context context = this.delegate.currentContext(); + if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { + return context; + } + return Context.of( + CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE, + HTTP_SERVLET_REQUEST_ATTR_NAME, this.request, + HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response, + AUTHENTICATION_ATTR_NAME, this.authentication); + } + + @Override + public void onSubscribe(Subscription s) { + this.delegate.onSubscribe(s); + } + + @Override + public void onNext(T t) { + this.delegate.onNext(t); + } + + @Override + public void onError(Throwable t) { + this.delegate.onError(t); + } + + @Override + public void onComplete() { + this.delegate.onComplete(); + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 1984f032de..bfda009e83 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -74,13 +74,11 @@ import java.util.Map; import java.util.Optional; import java.util.function.Consumer; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*; @@ -572,6 +570,121 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request0)).isEmpty(); } + // gh-6483 + @Test + public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized + this.function.setDefaultOAuth2AuthorizedClient(true); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); + + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( + user, authorities, this.registration.getRegistrationId()); + SecurityContextHolder.getContext().setAuthentication(authentication); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), + eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); + + // Default request attributes set + final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com")) + .attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build(); + + // Default request attributes NOT set + final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build(); + + this.function.filter(request1, this.exchange) + .flatMap(response -> this.function.filter(request2, this.exchange)) + .block(); + + this.function.destroy(); // Hooks.onLastOperator() released + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(2); + + ClientRequest request = requests.get(0); + assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com"); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request)).isEmpty(); + + request = requests.get(1); + assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com"); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request)).isEmpty(); + } + + @Test + public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); +// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized + this.function.setDefaultOAuth2AuthorizedClient(true); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); + + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( + user, authorities, this.registration.getRegistrationId()); + SecurityContextHolder.getContext().setAuthentication(authentication); + + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + request = requests.get(0); + assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(request.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request)).isEmpty(); + } + + @Test + public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized + this.function.destroy(); // Hooks.onLastOperator() released + this.function.setDefaultOAuth2AuthorizedClient(true); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); + + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( + user, authorities, this.registration.getRegistrationId()); + SecurityContextHolder.getContext().setAuthentication(authentication); + + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + request = requests.get(0); + assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(request.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request)).isEmpty(); + } + private static String getBody(ClientRequest request) { final List> messageWriters = new ArrayList<>(); messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));