ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining
Fixes gh-6483
This commit is contained in:
parent
0c2a7e03f7
commit
0c27f64338
|
@ -16,6 +16,9 @@
|
|||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import org.reactivestreams.Subscription;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
|
@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
|
|||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.Hooks;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.publisher.Operators;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
|
|||
* @author Rob Winch
|
||||
* @since 5.1
|
||||
*/
|
||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
|
||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
||||
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
|
||||
|
||||
/**
|
||||
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
|
||||
*/
|
||||
|
@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
|
||||
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
|
||||
|
||||
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
|
||||
|
||||
private Clock clock = Clock.systemUTC();
|
||||
|
||||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
||||
|
@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
|
||||
private String defaultClientRegistrationId;
|
||||
|
||||
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
|
||||
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
|
||||
}
|
||||
|
||||
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
ClientRegistrationRepository clientRegistrationRepository,
|
||||
|
@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
this.authorizedClientRepository = authorizedClientRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() throws Exception {
|
||||
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
||||
* client_credentials grant.
|
||||
|
@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
|
||||
@Override
|
||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
||||
.map(OAuth2AuthorizedClient.class::cast);
|
||||
return Mono.justOrEmpty(attribute)
|
||||
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
|
||||
return Mono.just(request)
|
||||
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
||||
.switchIfEmpty(mergeRequestAttributesFromContext(request))
|
||||
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
||||
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
|
||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||
.flatMap(next::exchange)
|
||||
.switchIfEmpty(next.exchange(request));
|
||||
}
|
||||
|
||||
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
|
||||
return Mono.just(ClientRequest.from(request))
|
||||
.flatMap(builder -> Mono.subscriberContext()
|
||||
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
|
||||
.map(ClientRequest.Builder::build);
|
||||
}
|
||||
|
||||
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
|
||||
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
|
||||
}
|
||||
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
|
||||
}
|
||||
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
|
||||
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
|
||||
}
|
||||
populateDefaultOAuth2AuthorizedClient(attrs);
|
||||
}
|
||||
|
||||
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
|
||||
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
|
||||
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
||||
|
@ -435,6 +478,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
.build();
|
||||
}
|
||||
|
||||
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
|
||||
HttpServletRequest request = null;
|
||||
HttpServletResponse response = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes != null) {
|
||||
request = requestAttributes.getRequest();
|
||||
response = requestAttributes.getResponse();
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
return new RequestContextSubscriber<>(delegate, request, response, authentication);
|
||||
}
|
||||
|
||||
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
|
||||
return BodyInserters
|
||||
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
|
||||
|
@ -508,4 +564,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
return new UnsupportedOperationException("Not Supported");
|
||||
}
|
||||
}
|
||||
|
||||
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
||||
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
|
||||
private final CoreSubscriber<T> delegate;
|
||||
private final HttpServletRequest request;
|
||||
private final HttpServletResponse response;
|
||||
private final Authentication authentication;
|
||||
|
||||
private RequestContextSubscriber(CoreSubscriber<T> delegate,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
Authentication authentication) {
|
||||
this.delegate = delegate;
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
this.authentication = authentication;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Context currentContext() {
|
||||
Context context = this.delegate.currentContext();
|
||||
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
|
||||
return context;
|
||||
}
|
||||
return Context.of(
|
||||
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
|
||||
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
|
||||
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
|
||||
AUTHENTICATION_ATTR_NAME, this.authentication);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Subscription s) {
|
||||
this.delegate.onSubscribe(s);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(T t) {
|
||||
this.delegate.onNext(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
this.delegate.onError(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
this.delegate.onComplete();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,14 +74,11 @@ import java.util.Map;
|
|||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
|
||||
|
||||
|
@ -647,6 +644,121 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
assertThat(getBody(request0)).isEmpty();
|
||||
}
|
||||
|
||||
// gh-6483
|
||||
@Test
|
||||
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
|
||||
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||
this.registration, "principalName", this.accessToken);
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
|
||||
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
|
||||
|
||||
// Default request attributes set
|
||||
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
|
||||
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
|
||||
|
||||
// Default request attributes NOT set
|
||||
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
|
||||
|
||||
this.function.filter(request1, this.exchange)
|
||||
.flatMap(response -> this.function.filter(request2, this.exchange))
|
||||
.block();
|
||||
|
||||
this.function.destroy(); // Hooks.onLastOperator() released
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
|
||||
ClientRequest request = requests.get(0);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
|
||||
request = requests.get(1);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
|
||||
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
request = requests.get(0);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
|
||||
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
||||
this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
|
||||
this.function.destroy(); // Hooks.onLastOperator() released
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
|
||||
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
|
||||
|
||||
OAuth2User user = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
|
||||
user, authorities, this.registration.getRegistrationId());
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
request = requests.get(0);
|
||||
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
|
||||
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
|
||||
assertThat(request.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request)).isEmpty();
|
||||
}
|
||||
|
||||
private static String getBody(ClientRequest request) {
|
||||
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
||||
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
||||
|
|
Loading…
Reference in New Issue