ServerOAuth2AuthorizedClientExchangeFilterFunction default ServerWebExchange
Leverage ServerWebExchange established by ServerWebExchangeReactorContextWebFilter Issue: gh-4921
This commit is contained in:
parent
ac78258847
commit
23726abb1e
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||||
|
|
||||||
|
import com.sun.security.ntlm.Server;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
@ -211,14 +212,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
||||||
return authorizedClient(request)
|
return authorizedClient(request)
|
||||||
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
|
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
|
||||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||||
.flatMap(next::exchange)
|
.flatMap(next::exchange)
|
||||||
.switchIfEmpty(next.exchange(request));
|
.switchIfEmpty(next.exchange(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
|
||||||
|
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||||
|
return Mono.justOrEmpty(exchange)
|
||||||
|
.switchIfEmpty(serverWebExchange());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<ServerWebExchange> serverWebExchange() {
|
||||||
|
return Mono.subscriberContext()
|
||||||
|
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||||
|
.map(c -> c.get(ServerWebExchange.class));
|
||||||
|
}
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
||||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
||||||
.map(OAuth2AuthorizedClient.class::cast);
|
.map(OAuth2AuthorizedClient.class::cast);
|
||||||
|
@ -231,10 +243,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
return Mono.empty();
|
return Mono.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
||||||
return currentAuthentication()
|
return currentAuthentication()
|
||||||
.flatMap(principal -> clientRegistrationId(request, principal)
|
.flatMap(principal -> clientRegistrationId(request, principal)
|
||||||
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
|
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,9 +300,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
|
||||||
if (shouldRefresh(authorizedClient)) {
|
if (shouldRefresh(authorizedClient)) {
|
||||||
return refreshAuthorizedClient(next, authorizedClient, exchange);
|
return serverWebExchange(request)
|
||||||
|
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
|
||||||
}
|
}
|
||||||
return Mono.just(authorizedClient);
|
return Mono.just(authorizedClient);
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,9 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
|
||||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||||
import org.springframework.web.reactive.function.BodyInserter;
|
import org.springframework.web.reactive.function.BodyInserter;
|
||||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||||
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
import reactor.util.context.Context;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
@ -83,6 +85,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
@Mock
|
@Mock
|
||||||
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private ServerWebExchange serverWebExchange;
|
||||||
|
|
||||||
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
||||||
|
|
||||||
private MockExchangeFunction exchange = new MockExchangeFunction();
|
private MockExchangeFunction exchange = new MockExchangeFunction();
|
||||||
|
@ -352,6 +357,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
|
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
|
||||||
|
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||||
|
|
||||||
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||||
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
|
"principalName", this.accessToken, refreshToken);
|
||||||
|
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
|
||||||
|
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
|
||||||
|
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||||
|
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
this.function.filter(request, this.exchange)
|
||||||
|
.subscriberContext(serverWebExchange())
|
||||||
|
.block();
|
||||||
|
|
||||||
|
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Context serverWebExchange() {
|
||||||
|
return Context.of(ServerWebExchange.class, this.serverWebExchange);
|
||||||
|
}
|
||||||
|
|
||||||
private static String getBody(ClientRequest request) {
|
private static String getBody(ClientRequest request) {
|
||||||
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
||||||
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
||||||
|
|
Loading…
Reference in New Issue