ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining

Fixes gh-6483
This commit is contained in:
Joe Grandja 2019-02-12 21:51:44 -05:00
parent 0c2a7e03f7
commit 0c27f64338
2 changed files with 231 additions and 12 deletions

View File

@ -16,6 +16,9 @@
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
* @author Rob Winch
* @since 5.1
*/
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
/**
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
*/
@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private String defaultClientRegistrationId;
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
}
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
ClientRegistrationRepository clientRegistrationRepository,
@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this.authorizedClientRepository = authorizedClientRepository;
}
@Override
public void afterPropertiesSet() throws Exception {
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
}
@Override
public void destroy() throws Exception {
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
}
/**
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
return Mono.justOrEmpty(attribute)
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
return Mono.just(request)
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.switchIfEmpty(mergeRequestAttributesFromContext(request))
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
}
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
return Mono.just(ClientRequest.from(request))
.flatMap(builder -> Mono.subscriberContext()
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
.map(ClientRequest.Builder::build);
}
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
}
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
}
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
}
populateDefaultOAuth2AuthorizedClient(attrs);
}
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
@ -435,6 +478,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
.build();
}
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
HttpServletRequest request = null;
HttpServletResponse response = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
request = requestAttributes.getRequest();
response = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return new RequestContextSubscriber<>(delegate, request, response, authentication);
}
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
return BodyInserters
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
@ -508,4 +564,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return new UnsupportedOperationException("Not Supported");
}
}
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
private final CoreSubscriber<T> delegate;
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Authentication authentication;
private RequestContextSubscriber(CoreSubscriber<T> delegate,
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
this.delegate = delegate;
this.request = request;
this.response = response;
this.authentication = authentication;
}
@Override
public Context currentContext() {
Context context = this.delegate.currentContext();
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
return context;
}
return Context.of(
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
AUTHENTICATION_ATTR_NAME, this.authentication);
}
@Override
public void onSubscribe(Subscription s) {
this.delegate.onSubscribe(s);
}
@Override
public void onNext(T t) {
this.delegate.onNext(t);
}
@Override
public void onError(Throwable t) {
this.delegate.onError(t);
}
@Override
public void onComplete() {
this.delegate.onComplete();
}
}
}

View File

@ -74,14 +74,11 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
@ -647,6 +644,121 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request0)).isEmpty();
}
// gh-6483
@Test
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
// Default request attributes set
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
// Default request attributes NOT set
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
this.function.filter(request1, this.exchange)
.flatMap(response -> this.function.filter(request2, this.exchange))
.block();
this.function.destroy(); // Hooks.onLastOperator() released
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
request = requests.get(1);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
this.function.filter(request, this.exchange).block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
this.function.destroy(); // Hooks.onLastOperator() released
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
user, authorities, this.registration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(authentication);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
this.function.filter(request, this.exchange).block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
}
private static String getBody(ClientRequest request) {
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));