ServerOAuth2AuthorizedClientExchangeFilterFunction default ServerWebExchange
Leverage ServerWebExchange established by ServerWebExchangeReactorContextWebFilter Issue: gh-4921
This commit is contained in:
parent
ac78258847
commit
23726abb1e
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import com.sun.security.ntlm.Server;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
|
@ -211,14 +212,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
|
||||
@Override
|
||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||
return authorizedClient(request)
|
||||
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
|
||||
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
|
||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||
.flatMap(next::exchange)
|
||||
.switchIfEmpty(next.exchange(request));
|
||||
}
|
||||
|
||||
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
|
||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||
return Mono.justOrEmpty(exchange)
|
||||
.switchIfEmpty(serverWebExchange());
|
||||
}
|
||||
|
||||
private Mono<ServerWebExchange> serverWebExchange() {
|
||||
return Mono.subscriberContext()
|
||||
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||
.map(c -> c.get(ServerWebExchange.class));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
||||
.map(OAuth2AuthorizedClient.class::cast);
|
||||
|
@ -231,10 +243,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
return Mono.empty();
|
||||
}
|
||||
|
||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||
return currentAuthentication()
|
||||
.flatMap(principal -> clientRegistrationId(request, principal)
|
||||
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
|
||||
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -289,9 +300,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
});
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
||||
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
|
||||
if (shouldRefresh(authorizedClient)) {
|
||||
return refreshAuthorizedClient(next, authorizedClient, exchange);
|
||||
return serverWebExchange(request)
|
||||
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
|
||||
}
|
||||
return Mono.just(authorizedClient);
|
||||
}
|
||||
|
|
|
@ -49,7 +49,9 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
|
|||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||
import org.springframework.web.reactive.function.BodyInserter;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
|
@ -83,6 +85,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
@Mock
|
||||
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||
|
||||
@Mock
|
||||
private ServerWebExchange serverWebExchange;
|
||||
|
||||
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
||||
|
||||
private MockExchangeFunction exchange = new MockExchangeFunction();
|
||||
|
@ -352,6 +357,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
"principalName", this.accessToken, refreshToken);
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
|
||||
.build();
|
||||
|
||||
this.function.filter(request, this.exchange)
|
||||
.subscriberContext(serverWebExchange())
|
||||
.block();
|
||||
|
||||
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
|
||||
}
|
||||
|
||||
private Context serverWebExchange() {
|
||||
return Context.of(ServerWebExchange.class, this.serverWebExchange);
|
||||
}
|
||||
|
||||
private static String getBody(ClientRequest request) {
|
||||
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
||||
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
||||
|
|
Loading…
Reference in New Issue