ServerOAuth2AuthorizedClientExchangeFilterFunction works with UnAuthenticatedServerOAuth2AuthorizedClientRepository

Fixes gh-7544
This commit is contained in:
Joe Grandja 2019-11-27 16:10:48 -05:00
parent 07b8aa0b1f
commit 80f256e425
2 changed files with 104 additions and 1 deletions

View File

@ -22,6 +22,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
@ -35,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
@ -124,6 +126,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.clientCredentials()
.password()
.build();
// gh-7544
if (authorizedClientRepository instanceof UnAuthenticatedServerOAuth2AuthorizedClientRepository) {
UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager =
new UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
clientRegistrationRepository,
(UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository);
unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
return unauthenticatedAuthorizedClientManager;
}
DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
@ -266,7 +279,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.clientCredentials(this::updateClientCredentialsProvider)
.password(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
.build();
((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) {
((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
} else {
((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
}
}
private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
@ -376,4 +393,52 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
.build();
}
private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository,
UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal();
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.switchIfEmpty(Mono.defer(() -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null)))
.flatMap(authorizedClient -> {
// Re-authorize
return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build())
.flatMap(this.authorizedClientProvider::authorize)
.flatMap(reauthorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, null).thenReturn(reauthorizedClient))
// Default to the existing authorizedClient if the client was not re-authorized
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
authorizeRequest.getAuthorizedClient() : authorizedClient);
})
.switchIfEmpty(Mono.deferWithContext(context ->
// Authorize
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
.flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build()))
.flatMap(this.authorizedClientProvider::authorize)
.flatMap(authorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null).thenReturn(authorizedClient))
.subscriberContext(context)
));
}
private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
this.authorizedClientProvider = authorizedClientProvider;
}
}
}

View File

@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -587,6 +588,43 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
}
// gh-7544
@Test
public void filterWhenClientCredentialsClientNotAuthorizedAndOutsideRequestContextThenGetNewToken() {
// Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating outside of a request context
ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy(new UnAuthenticatedServerOAuth2AuthorizedClientRepository());
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, unauthenticatedAuthorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(360)
.build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(Mono.just(registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(registration.getRegistrationId()))
.build();
this.function.filter(request, this.exchange).block();
verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any());
verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());
verify(unauthenticatedAuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
private Context serverWebExchange() {
return Context.of(ServerWebExchange.class, this.serverWebExchange);
}