Reactive Implementation of AuthorizedClientServiceOAuth2AuthorizedClientManager

ReactiveOAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager is reactive
version of AuthorizedClientServiceOAuth2AuthorizedClientManager

Fixes: gh-7569
This commit is contained in:
Ankur Pathak 2019-10-31 19:21:39 +05:30 committed by Joe Grandja
parent 0c47bfb1e3
commit c29309d744
2 changed files with 447 additions and 0 deletions

View File

@ -0,0 +1,147 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
/**
* An implementation of an {@link ReactiveOAuth2AuthorizedClientManager}
* that is capable of operating outside of a {@code ServerHttpRequest} context,
* e.g. in a scheduled/background thread and/or in the service-tier.
*
* @author Ankur Pathak
* @see ReactiveOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
* @see ReactiveOAuth2AuthorizedClientService
* @since 5.3
*/
public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();
/**
* Constructs an {@code OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientService the authorized client service
*/
public OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository,
ReactiveOAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService;
}
@Nullable
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient();
Authentication principal = authorizeRequest.getPrincipal();
// @formatter:off
return Mono.justOrEmpty(authorizedClient)
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))
)
.switchIfEmpty(Mono.error(new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
)
)
.flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest)
.filter(contextAttributes-> !CollectionUtils.isEmpty(contextAttributes))
.map(contextAttributes -> contextBuilder.principal(principal)
.attributes(attributes -> {
attributes.putAll(contextAttributes);
}).build())
).flatMap(authorizationContext -> this.authorizedClientProvider.authorize(authorizationContext)
.doOnNext(_authorizedClient -> authorizedClientService.saveAuthorizedClient(_authorizedClient, principal))
.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(Optional.ofNullable(authorizationContext.getAuthorizedClient()))))
);
// @formatter:on
}
/**
* Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client.
*
* @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client
*/
public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
this.authorizedClientProvider = authorizedClientProvider;
}
/**
* Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes
* to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}.
*
* @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes
* to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}
*/
public void setContextAttributesMapper(Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper) {
Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null");
this.contextAttributesMapper = contextAttributesMapper;
}
private static Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/
public static class DefaultContextAttributesMapper implements Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> {
@Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
})
.defaultIfEmpty(Collections.emptyMap());
}
}
}

View File

@ -0,0 +1,300 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
/**
* Tests for {@link OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}.
*
* @author Ankur Pathak
*/
public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
@SuppressWarnings("unchecked")
@Before
public void setup() {
this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class);
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizedClientManager = new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientService);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
}
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveClientRegistrationRepository cannot be null");
}
@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientService cannot be null");
}
@Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientProvider cannot be null");
}
@Test
public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizeRequest cannot be null");
}
@Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
String clientRegistrationId = "invalid-registration-id";
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId)
.principal(this.principal)
.build();
when(this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(Mono.empty());
StepVerifier.create(this.authorizedClientManager.authorize(authorizeRequest))
.verifyError(IllegalArgumentException.class);
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientService.loadAuthorizedClient(
any(), any())).thenReturn(Mono.empty());
when(authorizedClientProvider.authorize(any())).thenReturn(Mono.empty());
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectComplete();
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientService.loadAuthorizedClient(
any(), any())).thenReturn(Mono.empty());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientService.loadAuthorizedClient(
eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenSupportedProviderThenReauthorized() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2ParameterNames.SCOPE, "read write")
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
}