Rename OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager to AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.

Handle empty mono returned from contextAttributesMapper.

Handle empty map returned from contextAttributesMapper.

Fix DefaultContextAttributesMapper so that it doesn't access ServerWebExchange.

Fix unit tests so that they pass.

Use StepVerifier in unit tests, rather than .subscribe().

Fixes gh-7569
This commit is contained in:
Phil Clay 2019-12-05 13:14:06 -08:00 committed by Joe Grandja
parent c29309d744
commit cffad1be02
2 changed files with 87 additions and 80 deletions

View File

@ -15,20 +15,13 @@
*/
package org.springframework.security.oauth2.client;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
/**
@ -36,25 +29,31 @@ import java.util.function.Function;
* that is capable of operating outside of a {@code ServerHttpRequest} context,
* e.g. in a scheduled/background thread and/or in the service-tier.
*
* <p>This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}</p>
*
* @author Ankur Pathak
* @author Phil Clay
* @see ReactiveOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
* @see ReactiveOAuth2AuthorizedClientService
* @since 5.3
* @since 5.2.2
*/
public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
implements ReactiveOAuth2AuthorizedClientManager {
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();
/**
* Constructs an {@code OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
* Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientService the authorized client service
*/
public OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository,
public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ReactiveOAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
@ -62,35 +61,42 @@ public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientMa
this.authorizedClientService = authorizedClientService;
}
@Nullable
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
return createAuthorizationContext(authorizeRequest)
.flatMap(this::authorizeAndSave);
}
private Mono<OAuth2AuthorizationContext> createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) {
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient();
Authentication principal = authorizeRequest.getPrincipal();
// @formatter:off
return Mono.justOrEmpty(authorizedClient)
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))
)
.switchIfEmpty(Mono.error(new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
)
)
.flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName())
.map(OAuth2AuthorizationContext::withAuthorizedClient)
.switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration))))
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")))))
.flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest)
.filter(contextAttributes-> !CollectionUtils.isEmpty(contextAttributes))
.map(contextAttributes -> contextBuilder.principal(principal)
.attributes(attributes -> {
attributes.putAll(contextAttributes);
}).build())
).flatMap(authorizationContext -> this.authorizedClientProvider.authorize(authorizationContext)
.doOnNext(_authorizedClient -> authorizedClientService.saveAuthorizedClient(_authorizedClient, principal))
.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(Optional.ofNullable(authorizationContext.getAuthorizedClient()))))
);
// @formatter:on
.defaultIfEmpty(Collections.emptyMap())
.map(contextAttributes -> {
OAuth2AuthorizationContext.Builder builder = contextBuilder.principal(principal);
if (!contextAttributes.isEmpty()) {
builder = builder.attributes(attributes -> attributes.putAll(contextAttributes));
}
return builder.build();
}));
}
private Mono<OAuth2AuthorizedClient> authorizeAndSave(OAuth2AuthorizationContext authorizationContext) {
return this.authorizedClientProvider.authorize(authorizationContext)
.flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient(
authorizedClient,
authorizationContext.getPrincipal())
.thenReturn(authorizedClient))
.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient())));
}
/**
@ -115,33 +121,17 @@ public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientMa
this.contextAttributesMapper = contextAttributesMapper;
}
private static Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/
public static class DefaultContextAttributesMapper implements Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> {
private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper =
new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper();
@Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
})
.defaultIfEmpty(Collections.emptyMap());
return Mono.fromCallable(() -> mapper.apply(authorizeRequest));
}
}
}

View File

@ -28,39 +28,49 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import java.util.Map;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}.
* Tests for {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}.
*
* @author Ankur Pathak
* @author Phil Clay
*/
public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager;
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper;
private AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
private PublisherProbe<Void> saveAuthorizedClientProbe;
@SuppressWarnings("unchecked")
@Before
public void setup() {
this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class);
this.saveAuthorizedClientProbe = PublisherProbe.empty();
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono());
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizedClientManager = new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty());
this.authorizedClientManager = new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientService);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
@ -73,23 +83,23 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService))
assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveClientRegistrationRepository cannot be null");
.hasMessage("clientRegistrationRepository cannot be null");
}
@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientService cannot be null");
.hasMessage("authorizedClientService cannot be null");
}
@Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("reactiveAuthorizedClientProvider cannot be null");
.hasMessage("authorizedClientProvider cannot be null");
}
@Test
@ -132,7 +142,7 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient).verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -142,7 +152,6 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectComplete();
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@ -163,7 +172,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(this.authorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -173,9 +184,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}
@SuppressWarnings("unchecked")
@ -197,8 +208,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -207,9 +219,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}
@SuppressWarnings("unchecked")
@ -221,8 +233,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(this.authorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -231,7 +244,6 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@ -250,7 +262,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
.build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -260,9 +274,9 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
}
@SuppressWarnings("unchecked")
@ -274,14 +288,20 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2ParameterNames.SCOPE, "read write")
.build();
this.authorizedClientManager.setContextAttributesMapper(new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
authorizedClient.subscribe();
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -293,8 +313,5 @@ public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerT
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient));
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
}