ServerOAuth2AuthorizedClientExchangeFilterFunction default ServerWebExchange

Leverage ServerWebExchange established by ServerWebExchangeReactorContextWebFilter

Issue: gh-4921
This commit is contained in:
Rob Winch 2018-09-05 20:27:04 -05:00
parent ac78258847
commit 23726abb1e
2 changed files with 47 additions and 6 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.web.reactive.function.client;
import com.sun.security.ntlm.Server;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -211,14 +212,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
return authorizedClient(request)
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
}
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
return Mono.justOrEmpty(exchange)
.switchIfEmpty(serverWebExchange());
}
private Mono<ServerWebExchange> serverWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
@ -231,10 +243,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return Mono.empty();
}
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
return currentAuthentication()
.flatMap(principal -> clientRegistrationId(request, principal)
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
);
}
@ -289,9 +300,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
});
}
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
if (shouldRefresh(authorizedClient)) {
return refreshAuthorizedClient(next, authorizedClient, exchange);
return serverWebExchange(request)
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
}
return Mono.just(authorizedClient);
}

View File

@ -49,7 +49,9 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.net.URI;
import java.time.Duration;
@ -83,6 +85,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private ServerWebExchange serverWebExchange;
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
private MockExchangeFunction exchange = new MockExchangeFunction();
@ -352,6 +357,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
}
@Test
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
.build();
this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
}
private Context serverWebExchange() {
return Context.of(ServerWebExchange.class, this.serverWebExchange);
}
private static String getBody(ClientRequest request) {
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));