OAuth2AuthorizedClientManager implementation works outside of request

Fixes gh-6780
This commit is contained in:
Joe Grandja 2019-07-22 11:49:33 -04:00
parent a60446836b
commit f7d03858f1
2 changed files with 422 additions and 0 deletions

View File

@ -0,0 +1,142 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
/**
* An implementation of an {@link OAuth2AuthorizedClientManager}
* that is capable of operating outside of a {@code HttpServletRequest} context,
* e.g. in a scheduled/background thread and/or in the service-tier.
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2AuthorizedClientManager
* @see OAuth2AuthorizedClientProvider
* @see OAuth2AuthorizedClientService
*/
public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
/**
* Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientService the authorized client service
*/
public AuthorizedClientServiceOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService;
}
@Nullable
@Override
public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient();
Authentication principal = authorizeRequest.getPrincipal();
OAuth2AuthorizationContext.Builder contextBuilder;
if (authorizedClient != null) {
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
} else {
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
authorizedClient = this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName());
if (authorizedClient != null) {
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
} else {
contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
}
}
OAuth2AuthorizationContext authorizationContext = contextBuilder
.principal(principal)
.attributes(this.contextAttributesMapper.apply(authorizeRequest))
.build();
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
if (authorizedClient != null) {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
} else {
// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
// For these cases, return the provided `authorizationContext.authorizedClient`.
if (authorizationContext.getAuthorizedClient() != null) {
return authorizationContext.getAuthorizedClient();
}
}
return authorizedClient;
}
/**
* Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client.
*
* @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client
*/
public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) {
Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
this.authorizedClientProvider = authorizedClientProvider;
}
/**
* Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes
* to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}.
*
* @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes
* to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}
*/
public void setContextAttributesMapper(Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper) {
Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null");
this.contextAttributesMapper = contextAttributesMapper;
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/
public static class DefaultContextAttributesMapper implements Function<OAuth2AuthorizeRequest, Map<String, Object>> {
@Override
public Map<String, Object> apply(OAuth2AuthorizeRequest authorizeRequest) {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = authorizeRequest.getAttribute(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return contextAttributes;
}
}
}

View File

@ -0,0 +1,280 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
/**
* Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}.
*
* @author Joe Grandja
*/
public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
@SuppressWarnings("unchecked")
@Before
public void setup() {
this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientService);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
}
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientRegistrationRepository cannot be null");
}
@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClientService cannot be null");
}
@Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClientProvider cannot be null");
}
@Test
public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizeRequest cannot be null");
}
@Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
.principal(this.principal)
.build();
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull();
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientService).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
when(this.authorizedClientService.loadAuthorizedClient(
eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(this.authorizedClient);
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenSupportedProviderThenReauthorized() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
// Override the mock with the default
this.authorizedClientManager.setContextAttributesMapper(
new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2ParameterNames.SCOPE, "read write")
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
}