Introduce Reactive OAuth2AuthorizedClient Manager/Provider

Fixes gh-7116
This commit is contained in:
Joe Grandja 2019-07-11 13:41:44 -04:00
parent a377581951
commit 46756d2e6b
25 changed files with 2918 additions and 276 deletions

View File

@ -20,11 +20,14 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportSelector; import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata; import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.web.reactive.config.WebFluxConfigurer; import org.springframework.web.reactive.config.WebFluxConfigurer;
import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
@ -63,7 +66,16 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
@Override @Override
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) {
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository())); ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, getAuthorizedClientRepository());
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
} }
} }

View File

@ -0,0 +1,53 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
/**
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider}
* for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
*/
public final class AuthorizationCodeReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider {
/**
* Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}.
* Returns an empty {@code Mono} if authorization is not supported,
* e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type}
* is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized.
*
* @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported
*/
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) &&
context.getAuthorizedClient() == null) {
// ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectWebFilter which initiates authorization
return Mono.error(() -> new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()));
}
return Mono.empty();
}
}

View File

@ -0,0 +1,108 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
/**
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider}
* for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveClientCredentialsTokenResponseClient
*/
public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient =
new WebClientReactiveClientCredentialsTokenResponseClient();
private Duration clockSkew = Duration.ofSeconds(60);
/**
* Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}.
* Returns an empty {@code Mono} if authorization (or re-authorization) is not supported,
* e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type}
* is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR
* the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired.
*
* @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization (or re-authorization) is not supported
*/
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
ClientRegistration clientRegistration = context.getClientRegistration();
if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return Mono.empty();
}
OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) {
// If client is already authorized but access token is NOT expired than no need for re-authorization
return Mono.empty();
}
// As per spec, in section 4.4.3 Access Token Response
// https://tools.ietf.org/html/rfc6749#section-4.4.3
// A refresh token SHOULD NOT be included.
//
// Therefore, renewing an expired access token (re-authorization)
// is the same as acquiring a new access token (authorization).
return Mono.just(new OAuth2ClientCredentialsGrantRequest(clientRegistration))
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.map(tokenResponse -> new OAuth2AuthorizedClient(
clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()));
}
private boolean hasTokenExpired(AbstractOAuth2Token token) {
return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew));
}
/**
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
*
* @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant
*/
public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds.
* An access token is considered expired if it's before {@code Instant.now() - clockSkew}.
*
* @param clockSkew the maximum acceptable clock skew
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}
}

View File

@ -0,0 +1,70 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} that simply delegates
* to it's internal {@code List} of {@link ReactiveOAuth2AuthorizedClientProvider}(s).
* <p>
* Each provider is given a chance to
* {@link ReactiveOAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize}
* the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context
* with the first available {@link OAuth2AuthorizedClient} being returned.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
*/
public final class DelegatingReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider {
private final List<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders;
/**
* Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters.
*
* @param authorizedClientProviders a list of {@link ReactiveOAuth2AuthorizedClientProvider}(s)
*/
public DelegatingReactiveOAuth2AuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider... authorizedClientProviders) {
Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty");
this.authorizedClientProviders = Collections.unmodifiableList(Arrays.asList(authorizedClientProviders));
}
/**
* Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters.
*
* @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s)
*/
public DelegatingReactiveOAuth2AuthorizedClientProvider(List<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty");
this.authorizedClientProviders = Collections.unmodifiableList(new ArrayList<>(authorizedClientProviders));
}
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
return Flux.fromIterable(this.authorizedClientProviders)
.concatMap(authorizedClientProvider -> authorizedClientProvider.authorize(context))
.next();
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import reactor.core.publisher.Mono;
/**
* A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client.
* Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type.
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizationContext
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3">Section 1.3 Authorization Grant</a>
*/
public interface ReactiveOAuth2AuthorizedClientProvider {
/**
* Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context.
* Implementations must return an empty {@code Mono} if authorization is not supported for the specified client,
* e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client.
*
* @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client
*/
Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context);
}

View File

@ -0,0 +1,267 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.util.Assert;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
/**
* A builder that builds a {@link DelegatingReactiveOAuth2AuthorizedClientProvider} composed of
* one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s) that implement specific authorization grants.
* The supported authorization grants are {@link #authorizationCode() authorization_code},
* {@link #refreshToken() refresh_token} and {@link #clientCredentials() client_credentials}.
* In addition to the standard authorization grants, an implementation of an extension grant
* may be supplied via {@link #provider(ReactiveOAuth2AuthorizedClientProvider)}.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see AuthorizationCodeReactiveOAuth2AuthorizedClientProvider
* @see RefreshTokenReactiveOAuth2AuthorizedClientProvider
* @see ClientCredentialsReactiveOAuth2AuthorizedClientProvider
* @see DelegatingReactiveOAuth2AuthorizedClientProvider
*/
public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
private final Map<Class<?>, Builder> builders = new LinkedHashMap<>();
private ReactiveOAuth2AuthorizedClientProviderBuilder() {
}
/**
* Returns a new {@link ReactiveOAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s).
*
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public static ReactiveOAuth2AuthorizedClientProviderBuilder builder() {
return new ReactiveOAuth2AuthorizedClientProviderBuilder();
}
/**
* Configures a {@link ReactiveOAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingReactiveOAuth2AuthorizedClientProvider}.
* This may be used for implementations of extension authorization grants.
*
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder provider(ReactiveOAuth2AuthorizedClientProvider provider) {
Assert.notNull(provider, "provider cannot be null");
this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider);
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* Configures support for the {@code authorization_code} grant.
*
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder authorizationCode() {
this.builders.computeIfAbsent(AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder());
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* A builder for the {@code authorization_code} grant.
*/
public class AuthorizationCodeGrantBuilder implements Builder {
private AuthorizationCodeGrantBuilder() {
}
/**
* Builds an instance of {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}.
*
* @return the {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}
*/
@Override
public ReactiveOAuth2AuthorizedClientProvider build() {
return new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider();
}
}
/**
* Configures support for the {@code refresh_token} grant.
*
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken() {
this.builders.computeIfAbsent(RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder());
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* Configures support for the {@code refresh_token} grant.
*
* @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken(Consumer<RefreshTokenGrantBuilder> builderConsumer) {
RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent(
RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder());
builderConsumer.accept(builder);
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* A builder for the {@code refresh_token} grant.
*/
public class RefreshTokenGrantBuilder implements Builder {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
private Duration clockSkew;
private RefreshTokenGrantBuilder() {
}
/**
* Sets the client used when requesting an access token credential at the Token Endpoint.
*
* @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint
* @return the {@link RefreshTokenGrantBuilder}
*/
public RefreshTokenGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient) {
this.accessTokenResponseClient = accessTokenResponseClient;
return this;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the access token expiry.
* An access token is considered expired if it's before {@code Instant.now() - clockSkew}.
*
* @param clockSkew the maximum acceptable clock skew
* @return the {@link RefreshTokenGrantBuilder}
*/
public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) {
this.clockSkew = clockSkew;
return this;
}
/**
* Builds an instance of {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @return the {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}
*/
@Override
public ReactiveOAuth2AuthorizedClientProvider build() {
RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider();
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
return authorizedClientProvider;
}
}
/**
* Configures support for the {@code client_credentials} grant.
*
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials() {
this.builders.computeIfAbsent(ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder());
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* Configures support for the {@code client_credentials} grant.
*
* @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration
* @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder}
*/
public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer<ClientCredentialsGrantBuilder> builderConsumer) {
ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent(
ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder());
builderConsumer.accept(builder);
return ReactiveOAuth2AuthorizedClientProviderBuilder.this;
}
/**
* A builder for the {@code client_credentials} grant.
*/
public class ClientCredentialsGrantBuilder implements Builder {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient;
private Duration clockSkew;
private ClientCredentialsGrantBuilder() {
}
/**
* Sets the client used when requesting an access token credential at the Token Endpoint.
*
* @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint
* @return the {@link ClientCredentialsGrantBuilder}
*/
public ClientCredentialsGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient) {
this.accessTokenResponseClient = accessTokenResponseClient;
return this;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the access token expiry.
* An access token is considered expired if it's before {@code Instant.now() - clockSkew}.
*
* @param clockSkew the maximum acceptable clock skew
* @return the {@link ClientCredentialsGrantBuilder}
*/
public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) {
this.clockSkew = clockSkew;
return this;
}
/**
* Builds an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}.
*
* @return the {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}
*/
@Override
public ReactiveOAuth2AuthorizedClientProvider build() {
ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider();
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
return authorizedClientProvider;
}
}
/**
* Builds an instance of {@link DelegatingReactiveOAuth2AuthorizedClientProvider}
* composed of one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s).
*
* @return the {@link DelegatingReactiveOAuth2AuthorizedClientProvider}
*/
public ReactiveOAuth2AuthorizedClientProvider build() {
List<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders =
this.builders.values().stream()
.map(Builder::build)
.collect(Collectors.toList());
return new DelegatingReactiveOAuth2AuthorizedClientProvider(authorizedClientProviders);
}
interface Builder {
ReactiveOAuth2AuthorizedClientProvider build();
}
}

View File

@ -0,0 +1,119 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveRefreshTokenTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/**
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider}
* for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveRefreshTokenTokenResponseClient
*/
public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient =
new WebClientReactiveRefreshTokenTokenResponseClient();
private Duration clockSkew = Duration.ofSeconds(60);
/**
* Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}.
* Returns an empty {@code Mono} if re-authorization is not supported,
* e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token}
* is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired.
*
* <p>
* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported:
* <ol>
* <li>{@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s)
* to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}</li>
* </ol>
*
* @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if re-authorization is not supported
*/
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
if (authorizedClient == null ||
authorizedClient.getRefreshToken() == null ||
!hasTokenExpired(authorizedClient.getAccessToken())) {
return Mono.empty();
}
Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
Set<String> scopes = Collections.emptySet();
if (requestScope != null) {
Assert.isInstanceOf(String[].class, requestScope,
"The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
scopes = new HashSet<>(Arrays.asList((String[]) requestScope));
}
ClientRegistration clientRegistration = context.getClientRegistration();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes);
return Mono.just(refreshTokenGrantRequest)
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.map(tokenResponse -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
}
private boolean hasTokenExpired(AbstractOAuth2Token token) {
return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew));
}
/**
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant.
*
* @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant
*/
public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds.
* An access token is considered expired if it's before {@code Instant.now() - clockSkew}.
*
* @param clockSkew the maximum acceptable clock skew
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}
}

View File

@ -0,0 +1,142 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.function.Consumer;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
/**
* An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient}
* for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
* This implementation uses {@link WebClient} when requesting
* an access token credential at the Authorization Server's Token Endpoint.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveOAuth2AccessTokenResponseClient
* @see OAuth2RefreshTokenGrantRequest
* @see OAuth2AccessTokenResponse
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
*/
public final class WebClientReactiveRefreshTokenTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> {
private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
private WebClient webClient = WebClient.builder().build();
@Override
public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null");
return Mono.defer(() -> {
ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
return this.webClient.post()
.uri(clientRegistration.getProviderDetails().getTokenUri())
.headers(tokenRequestHeaders(clientRegistration))
.body(tokenRequestBody(refreshTokenGrantRequest))
.exchange()
.flatMap(response -> {
if (!response.statusCode().is2xxSuccessful()) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " +
"HTTP Status Code " + response.rawStatusCode(), null);
return response
.bodyToMono(DataBuffer.class)
.map(DataBufferUtils::release)
.then(Mono.error(new OAuth2AuthorizationException(oauth2Error)));
}
return response.body(oauth2AccessTokenResponse());
})
.map(tokenResponse -> tokenResponse(refreshTokenGrantRequest, tokenResponse));
});
}
private static Consumer<HttpHeaders> tokenRequestHeaders(ClientRegistration clientRegistration) {
return headers -> {
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
}
};
}
private static BodyInserters.FormInserter<String> tokenRequestBody(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
BodyInserters.FormInserter<String> body = BodyInserters.fromFormData(
OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue());
body.with(OAuth2ParameterNames.REFRESH_TOKEN,
refreshTokenGrantRequest.getRefreshToken().getTokenValue());
if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) {
body.with(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " "));
}
if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return body;
}
private static OAuth2AccessTokenResponse tokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest,
OAuth2AccessTokenResponse accessTokenResponse) {
if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) &&
accessTokenResponse.getRefreshToken() != null) {
return accessTokenResponse;
}
OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(accessTokenResponse);
if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) {
// As per spec, in Section 5.1 Successful Access Token Response
// https://tools.ietf.org/html/rfc6749#section-5.1
// If AccessTokenResponse.scope is empty, then default to the scope
// originally requested by the client in the Token Request
tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes());
}
if (accessTokenResponse.getRefreshToken() == null) {
// Reuse existing refresh token
tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue());
}
return tokenResponseBuilder.build();
}
/**
* Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response.
*
* @param webClient the {@link WebClient} used when requesting the Access Token Response
*/
public void setWebClient(WebClient webClient) {
Assert.notNull(webClient, "webClient cannot be null");
this.webClient = webClient;
}
}

View File

@ -16,23 +16,26 @@
package org.springframework.security.oauth2.client.web.reactive.function.client; package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
@ -40,21 +43,17 @@ import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.net.URI;
import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
/** /**
* Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
* token as a Bearer Token. * token as a Bearer Token.
* *
* @author Rob Winch * @author Rob Winch
* @author Joe Grandja
* @since 5.1 * @since 5.1
*/ */
public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@ -76,21 +75,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER")); AuthorityUtils.createAuthorityList("ROLE_USER"));
private Clock clock = Clock.systemUTC(); private ServerOAuth2AuthorizedClientManager authorizedClientManager;
private boolean defaultAuthorizedClientManager;
private boolean defaultOAuth2AuthorizedClient;
private String defaultClientRegistrationId;
@Deprecated
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; @Deprecated
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { /**
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository)); * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
*
* @since 5.2
* @param authorizedClientManager the {@link ServerOAuth2AuthorizedClientManager} which manages the authorized client(s)
*/
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager authorizedClientManager) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
this.authorizedClientManager = authorizedClientManager;
} }
ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) { /**
this.authorizedClientRepository = authorizedClientRepository; * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
this.authorizedClientResolver = authorizedClientResolver; *
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientRepository the repository of authorized clients
*/
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
this.defaultAuthorizedClientManager = true;
}
private static ServerOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
return authorizedClientManager;
} }
/** /**
@ -99,7 +136,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* *
* <pre> * <pre>
* WebClient webClient = WebClient.builder() * WebClient webClient = WebClient.builder()
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository)) * .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
* .build(); * .build();
* Mono<String> response = webClient * Mono<String> response = webClient
* .get() * .get()
@ -114,8 +151,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* are true: * are true:
* *
* <ul> * <ul>
* <li>The ReactiveOAuth2AuthorizedClientService on the
* {@link ServerOAuth2AuthorizedClientExchangeFilterFunction} is not null</li>
* <li>A refresh token is present on the OAuth2AuthorizedClient</li> * <li>A refresh token is present on the OAuth2AuthorizedClient</li>
* <li>The access token will be expired in * <li>The access token will be expired in
* {@link #setAccessTokenExpiresSkew(Duration)}</li> * {@link #setAccessTokenExpiresSkew(Duration)}</li>
@ -136,12 +171,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
} }
/** /**
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for * Modifies the {@link ClientRequest#attributes()} to include the {@link ServerWebExchange} to be used for
* providing the Bearer Token. Example usage: * providing the Bearer Token. Example usage:
* *
* <pre> * <pre>
* WebClient webClient = WebClient.builder() * WebClient webClient = WebClient.builder()
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository)) * .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
* .build(); * .build();
* Mono<String> response = webClient * Mono<String> response = webClient
* .get() * .get()
@ -190,7 +225,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* Default is false. * Default is false.
*/ */
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient); this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
} }
/** /**
@ -199,124 +234,127 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* @param clientRegistrationId the id to use * @param clientRegistrationId the id to use
*/ */
public void setDefaultClientRegistrationId(String clientRegistrationId) { public void setDefaultClientRegistrationId(String clientRegistrationId) {
this.authorizedClientResolver.setDefaultClientRegistrationId(clientRegistrationId); this.defaultClientRegistrationId = clientRegistrationId;
} }
/** /**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
* client_credentials grant. *
* @deprecated Use {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)} instead.
* Create an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} configured with a
* {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient) WebClientReactiveClientCredentialsTokenResponseClient}
* (or a custom one) and than supply it to {@link DefaultServerOAuth2AuthorizedClientManager#setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}.
*
* @param clientCredentialsTokenResponseClient the client to use * @param clientCredentialsTokenResponseClient the client to use
*/ */
@Deprecated
public void setClientCredentialsTokenResponseClient( public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) { ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
updateDefaultAuthorizedClientManager();
}
private void updateDefaultAuthorizedClientManager() {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
.clientCredentials(this::updateClientCredentialsProvider)
.build();
((DefaultServerOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
}
private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
if (this.clientCredentialsTokenResponseClient != null) {
builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient);
}
builder.clockSkew(this.accessTokenExpiresSkew);
} }
/** /**
* An access token will be considered expired by comparing its expiration to now + * An access token will be considered expired by comparing its expiration to now +
* this skewed Duration. The default is 1 minute. * this skewed Duration. The default is 1 minute.
*
* @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link ReactiveOAuth2AuthorizedClientProvider} implementation,
* e.g. {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsReactiveOAuth2AuthorizedClientProvider} or
* {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @param accessTokenExpiresSkew the Duration to use. * @param accessTokenExpiresSkew the Duration to use.
*/ */
@Deprecated
public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
this.accessTokenExpiresSkew = accessTokenExpiresSkew; this.accessTokenExpiresSkew = accessTokenExpiresSkew;
updateDefaultAuthorizedClientManager();
} }
@Override @Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return authorizedClient(request, next) return authorizedClient(request)
.map(authorizedClient -> bearer(request, authorizedClient)) .map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange) .flatMap(next::exchange)
.switchIfEmpty(Mono.defer(() -> next.exchange(request))); .switchIfEmpty(Mono.defer(() -> next.exchange(request)));
} }
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next) { private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request); OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
return Mono.justOrEmpty(authorizedClientFromAttrs) return Mono.justOrEmpty(authorizedClientFromAttrs)
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request))) .switchIfEmpty(Mono.defer(() ->
.flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient)); authorizeRequest(request).flatMap(this.authorizedClientManager::authorize)))
.flatMap(authorizedClient ->
reauthorizeRequest(request, authorizedClient).flatMap(this.authorizedClientManager::authorize));
} }
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(ClientRequest request) { private Mono<ServerOAuth2AuthorizeRequest> authorizeRequest(ClientRequest request) {
return createRequest(request) Mono<Authentication> authentication = currentAuthentication();
.flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r));
Mono<String> clientRegistrationId = Mono.justOrEmpty(clientRegistrationId(request))
.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
.switchIfEmpty(clientRegistrationId(authentication));
Mono<Optional<ServerWebExchange>> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
.switchIfEmpty(currentServerWebExchange())
.map(Optional::of)
.defaultIfEmpty(Optional.empty());
return Mono.zip(clientRegistrationId, authentication, serverWebExchange)
.map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
} }
private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest request) { private Mono<ServerOAuth2AuthorizeRequest> reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
String clientRegistrationId = clientRegistrationId(request); Mono<Authentication> authentication = currentAuthentication();
Authentication authentication = null;
ServerWebExchange exchange = serverWebExchange(request); Mono<Optional<ServerWebExchange>> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange); .switchIfEmpty(currentServerWebExchange())
.map(Optional::of)
.defaultIfEmpty(Optional.empty());
return Mono.zip(authentication, serverWebExchange)
.map(t2 -> new ServerOAuth2AuthorizeRequest(authorizedClient, t2.getT1(), t2.getT2().orElse(null)));
} }
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { private Mono<Authentication> currentAuthentication() {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); return ReactiveSecurityContextHolder.getContext()
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { .map(SecurityContext::getAuthentication)
return createRequest(request) .defaultIfEmpty(ANONYMOUS_USER_TOKEN);
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
} else if (shouldRefreshToken(authorizedClient)) {
return createRequest(request)
.flatMap(r -> authorizeWithRefreshToken(next, authorizedClient, r));
}
return Mono.just(authorizedClient);
} }
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); return authentication
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
} }
private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) { private Mono<ServerWebExchange> currentServerWebExchange() {
Authentication authentication = request.getAuthentication(); return Mono.subscriberContext()
ServerWebExchange exchange = request.getExchange(); .filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}
private Mono<OAuth2AuthorizedClient> authorizeWithRefreshToken(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient,
OAuth2AuthorizedClientResolver.Request r) {
ServerWebExchange exchange = r.getExchange();
Authentication authentication = r.getAuthentication();
ClientRegistration clientRegistration = authorizedClient
.getClientRegistration();
String tokenUri = clientRegistration
.getProviderDetails().getTokenUri();
ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
.build();
return next.exchange(refreshRequest)
.flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
.map(accessTokenResponse -> {
OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken())
.orElse(authorizedClient.getRefreshToken());
return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
})
.flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}
private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
return true;
}
return false;
} }
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
@ -324,10 +362,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
.build(); .build();
} }
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
return BodyInserters
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
.with("refresh_token", refreshToken);
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,9 +18,20 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -46,22 +57,53 @@ import reactor.core.publisher.Mono;
* </pre> * </pre>
* *
* @author Rob Winch * @author Rob Winch
* @author Joe Grandja
* @since 5.1 * @since 5.1
* @see RegisteredOAuth2AuthorizedClient * @see RegisteredOAuth2AuthorizedClient
*/ */
public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken(
private final OAuth2AuthorizedClientResolver authorizedClientResolver; "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));
private ServerOAuth2AuthorizedClientManager authorizedClientManager;
/** /**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
* *
* @param authorizedClientRepository the authorized client repository * @since 5.2
* @param authorizedClientManager the {@link ServerOAuth2AuthorizedClientManager} which manages the authorized client(s)
*/ */
public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientManager authorizedClientManager) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
this.authorizedClientManager = authorizedClientManager;
}
/**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientRepository the repository of authorized clients
*/
public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true); }
private static ServerOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
return authorizedClientManager;
} }
@Override @Override
@ -70,8 +112,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
} }
@Override @Override
public Mono<Object> resolveArgument( public Mono<Object> resolveArgument(MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) {
MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) {
return Mono.defer(() -> { return Mono.defer(() -> {
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
@ -79,8 +120,41 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ? String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ?
authorizedClientAnnotation.registrationId() : null; authorizedClientAnnotation.registrationId() : null;
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange) return authorizeRequest(clientRegistrationId, exchange)
.flatMap(this.authorizedClientResolver::loadAuthorizedClient); .flatMap(this.authorizedClientManager::authorize);
}); });
} }
private Mono<ServerOAuth2AuthorizeRequest> authorizeRequest(String registrationId, ServerWebExchange exchange) {
Mono<Authentication> defaultedAuthentication = currentAuthentication();
Mono<String> defaultedRegistrationId = Mono.justOrEmpty(registrationId)
.switchIfEmpty(clientRegistrationId(defaultedAuthentication))
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one")));
Mono<ServerWebExchange> defaultedExchange = Mono.justOrEmpty(exchange)
.switchIfEmpty(currentServerWebExchange());
return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
.map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3()));
}
private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
}
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
return authentication
.filter(t -> t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
}
private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
} }

View File

@ -0,0 +1,143 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
/**
* The default implementation of a {@link ServerOAuth2AuthorizedClientManager}.
*
* @author Joe Grandja
* @since 5.2
* @see ServerOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
*/
public final class DefaultServerOAuth2AuthorizedClientManager implements ServerOAuth2AuthorizedClientManager {
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
private Function<ServerOAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
/**
* Constructs a {@code DefaultServerOAuth2AuthorizedClientManager} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientRepository the repository of authorized clients
*/
public DefaultServerOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
@Override
public Mono<OAuth2AuthorizedClient> authorize(ServerOAuth2AuthorizeRequest authorizeRequest) {
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal();
ServerWebExchange serverWebExchange = authorizeRequest.getServerWebExchange();
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.switchIfEmpty(Mono.defer(() ->
this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
.flatMap(authorizedClient -> {
// Re-authorize
OAuth2AuthorizationContext reauthorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(principal)
.attributes(this.contextAttributesMapper.apply(authorizeRequest))
.build();
return Mono.just(reauthorizationContext)
.flatMap(this.authorizedClientProvider::authorize)
.doOnNext(reauthorizedClient ->
this.authorizedClientRepository.saveAuthorizedClient(
reauthorizedClient, principal, serverWebExchange))
// Return the `authorizedClient` if `reauthorizedClient` is null, e.g. re-authorization is not supported
.defaultIfEmpty(authorizedClient);
})
.switchIfEmpty(Mono.defer(() ->
// Authorize
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
.map(clientRegistration -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
.principal(principal)
.attributes(this.contextAttributesMapper.apply(authorizeRequest))
.build())
.flatMap(this.authorizedClientProvider::authorize)
.doOnNext(authorizedClient ->
this.authorizedClientRepository.saveAuthorizedClient(
authorizedClient, principal, serverWebExchange))
));
}
/**
* Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client.
*
* @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client
*/
public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
this.authorizedClientProvider = authorizedClientProvider;
}
/**
* Sets the {@code Function} used for mapping attribute(s) from the {@link ServerOAuth2AuthorizeRequest} to a {@code Map} of attributes
* to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}.
*
* @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes
* to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}
*/
public void setContextAttributesMapper(Function<ServerOAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper) {
Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null");
this.contextAttributesMapper = contextAttributesMapper;
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/
public static class DefaultContextAttributesMapper implements Function<ServerOAuth2AuthorizeRequest, Map<String, Object>> {
@Override
public Map<String, Object> apply(ServerOAuth2AuthorizeRequest authorizeRequest) {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = authorizeRequest.getServerWebExchange().getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return contextAttributes;
}
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
/**
* Represents a request the {@link ServerOAuth2AuthorizedClientManager} uses to
* {@link ServerOAuth2AuthorizedClientManager#authorize(ServerOAuth2AuthorizeRequest) authorize} (or re-authorize)
* the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}.
*
* @author Joe Grandja
* @since 5.2
* @see ServerOAuth2AuthorizedClientManager
*/
public class ServerOAuth2AuthorizeRequest {
private final String clientRegistrationId;
private final OAuth2AuthorizedClient authorizedClient;
private final Authentication principal;
private final ServerWebExchange serverWebExchange;
/**
* Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters.
*
* @param clientRegistrationId the identifier for the {@link ClientRegistration client registration}
* @param principal the {@code Principal} (to be) associated to the authorized client
* @param serverWebExchange the {@code ServerWebExchange}
*/
public ServerOAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal,
ServerWebExchange serverWebExchange) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.notNull(principal, "principal cannot be null");
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
this.clientRegistrationId = clientRegistrationId;
this.authorizedClient = null;
this.principal = principal;
this.serverWebExchange = serverWebExchange;
}
/**
* Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters.
*
* @param authorizedClient the {@link OAuth2AuthorizedClient authorized client}
* @param principal the {@code Principal} (to be) associated to the authorized client
* @param serverWebExchange the {@code ServerWebExchange}
*/
public ServerOAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal,
ServerWebExchange serverWebExchange) {
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
Assert.notNull(principal, "principal cannot be null");
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId();
this.authorizedClient = authorizedClient;
this.principal = principal;
this.serverWebExchange = serverWebExchange;
}
/**
* Returns the identifier for the {@link ClientRegistration client registration}.
*
* @return the identifier for the client registration
*/
public String getClientRegistrationId() {
return this.clientRegistrationId;
}
/**
* Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided.
*
* @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided
*/
@Nullable
public OAuth2AuthorizedClient getAuthorizedClient() {
return this.authorizedClient;
}
/**
* Returns the {@code Principal} (to be) associated to the authorized client.
*
* @return the {@code Principal} (to be) associated to the authorized client
*/
public Authentication getPrincipal() {
return this.principal;
}
/**
* Returns the {@link ServerWebExchange}.
*
* @return the {@link ServerWebExchange}
*/
public ServerWebExchange getServerWebExchange() {
return this.serverWebExchange;
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import reactor.core.publisher.Mono;
/**
* Implementations of this interface are responsible for the overall management
* of {@link OAuth2AuthorizedClient Authorized Client(s)}.
*
* <p>
* The primary responsibilities include:
* <ol>
* <li>Authorizing (or re-authorizing) an OAuth 2.0 Client
* by leveraging a {@link ReactiveOAuth2AuthorizedClientProvider}(s).</li>
* <li>Managing the persistence of an {@link OAuth2AuthorizedClient} between requests,
* typically using an {@link ServerOAuth2AuthorizedClientRepository}.</li>
* </ol>
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2AuthorizedClient
* @see ReactiveOAuth2AuthorizedClientProvider
* @see ServerOAuth2AuthorizedClientRepository
*/
public interface ServerOAuth2AuthorizedClientManager {
/**
* Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client}
* identified by the provided {@link ServerOAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}.
* Implementations must return an empty {@code Mono} if authorization is not supported for the specified client,
* e.g. the associated {@link ReactiveOAuth2AuthorizedClientProvider}(s) does not support
* the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client.
*
* <p>
* In the case of re-authorization, implementations must return the provided {@link ServerOAuth2AuthorizeRequest#getAuthorizedClient() authorized client}
* if re-authorization is not supported for the client OR is not required,
* e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR
* the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired.
*
* @param authorizeRequest the authorize request
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client
*/
Mono<OAuth2AuthorizedClient> authorize(ServerOAuth2AuthorizeRequest authorizeRequest);
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
*/
public class AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests {
private AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private ClientRegistration clientRegistration;
private OAuth2AuthorizedClient authorizedClient;
private Authentication principal;
@Before
public void setup() {
this.authorizedClientProvider = new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider();
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write"));
this.principal = new TestingAuthenticationToken("principal", "password");
}
@Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() {
ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build();
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() {
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() {
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block())
.isInstanceOf(ClientAuthorizationRequiredException.class);
}
}

View File

@ -0,0 +1,150 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
*/
public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
private ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient;
private ClientRegistration clientRegistration;
private Authentication principal;
@Before
public void setup() {
this.authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider();
this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class);
this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
this.clientRegistration = TestClientRegistrations.clientCredentials().build();
this.principal = new TestingAuthenticationToken("principal", "password");
}
@Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clockSkew cannot be null");
}
@Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clockSkew must be >= 0");
}
@Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}
@Test
public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test
public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() {
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60));
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test
public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes());
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import reactor.core.publisher.Mono;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link DelegatingReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
*/
public class DelegatingReactiveOAuth2AuthorizedClientProviderTests {
@Test
public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(new ReactiveOAuth2AuthorizedClientProvider[0]))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(Collections.emptyList()))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider(
mock(ReactiveOAuth2AuthorizedClientProvider.class));
assertThatThrownBy(() -> delegate.authorize(null).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}
@Test
public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() {
Authentication principal = new TestingAuthenticationToken("principal", "password");
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes());
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty());
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty());
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider3 = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(authorizedClientProvider3.authorize(any())).thenReturn(Mono.just(authorizedClient));
DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider(
authorizedClientProvider1, authorizedClientProvider2, authorizedClientProvider3);
OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
.principal(principal)
.build();
OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context).block();
assertThat(reauthorizedClient).isSameAs(authorizedClient);
}
@Test
public void authorizeWhenProviderCantAuthorizeThenReturnNull() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
.principal(new TestingAuthenticationToken("principal", "password"))
.build();
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty());
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty());
DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider(
authorizedClientProvider1, authorizedClientProvider2);
assertThat(delegate.authorize(context).block()).isNull();
}
}

View File

@ -44,28 +44,28 @@ public class OAuth2AuthorizationContextTests {
} }
@Test @Test
public void forClientWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { public void withClientRegistrationWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration((ClientRegistration) null).build()) assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(null).build())
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientRegistration cannot be null"); .hasMessage("clientRegistration cannot be null");
} }
@Test @Test
public void forClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { public void withAuthorizedClientWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient((OAuth2AuthorizedClient) null).build()) assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient(null).build())
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClient cannot be null"); .hasMessage("authorizedClient cannot be null");
} }
@Test @Test
public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { public void withClientRegistrationWhenPrincipalIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build())
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("principal cannot be null"); .hasMessage("principal cannot be null");
} }
@Test @Test
public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() {
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
.principal(this.principal) .principal(this.principal)
.attribute("attribute1", "value1") .attribute("attribute1", "value1")

View File

@ -0,0 +1,246 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
* Tests for {@link ReactiveOAuth2AuthorizedClientProviderBuilder}.
*
* @author Joe Grandja
*/
public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
private ClientRegistration.Builder clientRegistrationBuilder;
private Authentication principal;
private MockWebServer server;
@Before
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri);
this.principal = new TestingAuthenticationToken("principal", "password");
}
@After
public void cleanup() throws Exception {
this.server.shutdown();
}
@Test
public void providerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> ReactiveOAuth2AuthorizedClientProviderBuilder.builder().provider(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.build();
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block())
.isInstanceOf(ClientAuthorizationRequiredException.class);
}
@Test
public void buildWhenRefreshTokenProviderThenProviderReauthorizes() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.refreshToken()
.build();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistrationBuilder.build(),
this.principal.getName(),
expiredAccessToken(),
TestOAuth2RefreshTokens.refreshToken());
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext).block();
assertThat(reauthorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=refresh_token");
}
@Test
public void buildWhenClientCredentialsProviderThenProviderAuthorizes() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.clientCredentials()
.build();
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=client_credentials");
}
@Test
public void buildWhenAllProvidersThenProvidersAuthorize() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
// authorization_code
OAuth2AuthorizationContext authorizationCodeContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block())
.isInstanceOf(ClientAuthorizationRequiredException.class);
// refresh_token
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistrationBuilder.build(),
this.principal.getName(),
expiredAccessToken(),
TestOAuth2RefreshTokens.refreshToken());
OAuth2AuthorizationContext refreshTokenContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext).block();
assertThat(reauthorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=refresh_token");
// client_credentials
OAuth2AuthorizationContext clientCredentialsContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build())
.principal(this.principal)
.build();
authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext).block();
assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(2);
recordedRequest = this.server.takeRequest();
formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=client_credentials");
}
@Test
public void buildWhenCustomProviderThenProviderCalled() {
ReactiveOAuth2AuthorizedClientProvider customProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(customProvider.authorize(any())).thenReturn(Mono.empty());
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.provider(customProvider)
.build();
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
authorizedClientProvider.authorize(authorizationContext).block();
verify(customProvider).authorize(any(OAuth2AuthorizationContext.class));
}
private OAuth2AccessToken expiredAccessToken() {
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60));
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt);
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
}

View File

@ -0,0 +1,188 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashSet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
* Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
*/
public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
private RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
@Before
public void setup() {
this.authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider();
this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class);
this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60));
OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt);
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
expiredAccessToken, TestOAuth2RefreshTokens.refreshToken());
}
@Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clockSkew cannot be null");
}
@Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clockSkew must be >= 0");
}
@Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}
@Test
public void authorizeWhenNotAuthorizedThenUnableToReauthorize() {
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken());
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}
@Test
public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}
@Test
public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
String[] requestScope = new String[] { "read", "write" };
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope)
.build();
this.authorizedClientProvider.authorize(authorizationContext).block();
ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor =
ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture());
assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope)));
}
@Test
public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() {
String invalidRequestScope = "read write";
OAuth2AuthorizationContext authorizationContext =
OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope)
.build();
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageStartingWith("The context attribute must be of type String[] '" +
OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
}
}

View File

@ -0,0 +1,217 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import java.time.Instant;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}.
*
* @author Joe Grandja
*/
public class WebClientReactiveRefreshTokenTokenResponseClientTests {
private WebClientReactiveRefreshTokenTokenResponseClient tokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient();
private ClientRegistration.Builder clientRegistrationBuilder;
private OAuth2AccessToken accessToken;
private OAuth2RefreshToken refreshToken;
private MockWebServer server;
@Before
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri);
this.accessToken = TestOAuth2AccessTokens.scopes("read", "write");
this.refreshToken = TestOAuth2RefreshTokens.refreshToken();
}
@After
public void cleanup() throws Exception {
this.server.shutdown();
}
@Test
public void setWebClientWhenClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.tokenResponseClient.setWebClient(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block())
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block();
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString());
assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic ");
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=refresh_token");
assertThat(formParameters).contains("refresh_token=refresh-token");
assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234");
assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter);
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0]));
assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue());
}
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.POST)
.build();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
clientRegistration, this.accessToken, this.refreshToken);
this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block();
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("client_id=client-id");
assertThat(formParameters).contains("client_secret=client-secret");
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"not-bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
.isInstanceOf(OAuth2AuthorizationException.class)
.hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response")
.hasMessageContaining("Token type must be \"Bearer\"");
}
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"read\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read"));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block();
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("scope=read");
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read");
}
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\n" +
" \"error\": \"unauthorized_client\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
.isInstanceOf(OAuth2AuthorizationException.class)
.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
.hasMessageContaining("HTTP Status Code 400");
}
@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
this.server.enqueue(new MockResponse().setResponseCode(500));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
.isInstanceOf(OAuth2AuthorizationException.class)
.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
.hasMessageContaining("HTTP Status Code 500");
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
}

View File

@ -36,15 +36,23 @@ import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request; import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@ -68,12 +76,10 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.*;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.http.HttpMethod.GET; import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
@ -91,10 +97,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository; private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock @Mock
private OAuth2AuthorizedClientResolver authorizedClientResolver; private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
@Mock @Mock
private ServerWebExchange serverWebExchange; private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> refreshTokenTokenResponseClient;
private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
@Captor @Captor
private ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor; private ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor;
@ -113,7 +121,45 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Before @Before
public void setup() { public void setup() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
.clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
}
@Test
public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientCredentialsTokenResponseClient cannot be null");
}
@Test
public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() {
assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new WebClientReactiveClientCredentialsTokenResponseClient()))
.isInstanceOf(IllegalStateException.class)
.hasMessage("The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
}
@Test
public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() {
assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30)))
.isInstanceOf(IllegalStateException.class)
.hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
} }
@Test @Test
@ -134,7 +180,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
} }
@ -148,7 +196,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
HttpHeaders headers = this.exchange.getRequest().headers(); HttpHeaders headers = this.exchange.getRequest().headers();
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
@ -156,47 +206,35 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test @Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(360)
.build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
String clientRegistrationId = registration.getClientId();
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.authorizedClientResolver);
OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"new-token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", newAccessToken, null);
Request r = new Request(clientRegistrationId, authentication, null);
when(this.authorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
when(this.authorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(), this.accessToken.getTokenValue(),
issuedAt, issuedAt,
accessTokenExpiresAt); accessTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", accessToken, null); "principalName", accessToken, null);
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange())
.block(); .block();
verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.authorizedClientResolver).clientCredentials(any(), any(), any());
verify(this.authorizedClientResolver).createDefaultedRequest(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -212,8 +250,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.authorizedClientResolver);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", this.accessToken, null); "principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
@ -222,10 +258,10 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange())
.block(); .block();
verify(this.authorizedClientResolver, never()).clientCredentials(any(), any(), any()); verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
verify(this.authorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -238,24 +274,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test @Test
public void filterWhenRefreshRequiredThenRefresh() { public void filterWhenRefreshRequiredThenRefresh() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600) .expiresIn(3600)
.refreshToken("refresh-1") .refreshToken("refresh-1")
.build(); .build();
when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response));
Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(), this.accessToken.getTokenValue(),
issuedAt, issuedAt,
accessTokenExpiresAt); accessTokenExpiresAt);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken); "principalName", this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
@ -263,8 +298,10 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange())
.block(); .block();
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any());
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
@ -272,84 +309,26 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2); assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0); ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request0.method()).isEqualTo(HttpMethod.POST); assertThat(request0.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); assertThat(getBody(request0)).isEmpty();
ClientRequest request1 = requests.get(1);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
// .refreshToken(xxx) // No refreshToken in response
.build();
when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response));
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any());
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(authorizedClient.getRefreshToken());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2);
ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token");
assertThat(request0.method()).isEqualTo(HttpMethod.POST);
assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token");
ClientRequest request1 = requests.get(1);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
} }
@Test @Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
.tokenType(OAuth2AccessToken.TokenType.BEARER) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600) .expiresIn(3600)
.refreshToken("refresh-1") .refreshToken("refresh-1")
.build(); .build();
when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response));
Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(), this.accessToken.getTokenValue(),
issuedAt, issuedAt,
@ -363,24 +342,20 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.build(); .build();
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block(); .block();
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(2); assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0); ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request0.method()).isEqualTo(HttpMethod.POST); assertThat(request0.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); assertThat(getBody(request0)).isEmpty();
ClientRequest request1 = requests.get(1);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
} }
@Test @Test
@ -391,7 +366,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -412,7 +389,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -430,12 +409,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken); "principalName", this.accessToken, refreshToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(this.registration.getRegistrationId())) .attributes(clientRegistrationId(this.registration.getRegistrationId()))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -454,11 +434,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken); "principalName", this.accessToken, refreshToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build(); .build();
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -478,7 +459,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken); "principalName", this.accessToken, refreshToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build(); .build();
@ -488,6 +468,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.function this.function
.filter(request, this.exchange) .filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange())
.block(); .block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
@ -526,7 +507,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken); "principalName", this.accessToken, refreshToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(clientRegistrationId(this.registration.getRegistrationId())) .attributes(clientRegistrationId(this.registration.getRegistrationId()))
.build(); .build();

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,24 +22,32 @@ import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.util.context.Context; import reactor.util.context.Context;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -55,24 +63,50 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository; private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock @Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
private OAuth2AuthorizedClient authorizedClient;
private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
private ClientRegistration clientRegistration;
private OAuth2AuthorizedClient authorizedClient;
private Authentication authentication = new TestingAuthenticationToken("test", "this"); private Authentication authentication = new TestingAuthenticationToken("test", "this");
@Before @Before
public void setUp() { public void setUp() {
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository); ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
this.authorizedClient = mock(OAuth2AuthorizedClient.class); ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.authentication.getName(), TestOAuth2AccessTokens.noScopes());
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient)); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient));
} }
@Test @Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }
@Test
public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test @Test
public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() { public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() {
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
@ -101,8 +135,6 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
@Test @Test
public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
this.authentication = mock(OAuth2AuthenticationToken.class); this.authentication = mock(OAuth2AuthenticationToken.class);
when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
@ -112,24 +144,19 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
@Test @Test
public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() {
this.authentication = null; this.authentication = null;
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
} }
@Test @Test
public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
} }
@Test @Test
public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration));
TestClientRegistrations.clientRegistration().build()));
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty()); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty());
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThatThrownBy(() -> resolveArgument(methodParameter)) assertThatThrownBy(() -> resolveArgument(methodParameter))
@ -139,9 +166,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
private Object resolveArgument(MethodParameter methodParameter) { private Object resolveArgument(MethodParameter methodParameter) {
return this.argumentResolver.resolveArgument(methodParameter, null, null) return this.argumentResolver.resolveArgument(methodParameter, null, null)
.subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication)) .subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication))
.subscriberContext(serverWebExchange())
.block(); .block();
} }
private Context serverWebExchange() {
return Context.of(ServerWebExchange.class, this.serverWebExchange);
}
private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) { private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
Method method = ReflectionUtils.findMethod( Method method = ReflectionUtils.findMethod(
TestController.class, methodName, paramTypes); TestController.class, methodName, paramTypes);

View File

@ -0,0 +1,295 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
/**
* Tests for {@link DefaultServerOAuth2AuthorizedClientManager}.
*
* @author Joe Grandja
*/
public class DefaultServerOAuth2AuthorizedClientManagerTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private DefaultServerOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private MockServerWebExchange serverWebExchange;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
@SuppressWarnings("unchecked")
@Before
public void setup() {
this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
when(this.clientRegistrationRepository.findByRegistrationId(
anyString())).thenReturn(Mono.empty());
this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
when(this.authorizedClientRepository.loadAuthorizedClient(
anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
this.contextAttributesMapper = mock(Function.class);
when(this.contextAttributesMapper.apply(any())).thenReturn(Collections.emptyMap());
this.authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
}
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DefaultServerOAuth2AuthorizedClientManager(null, this.authorizedClientRepository))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientRegistrationRepository cannot be null");
}
@Test
public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DefaultServerOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClientRepository cannot be null");
}
@Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClientProvider cannot be null");
}
@Test
public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizeRequest cannot be null");
}
@Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
"invalid-registration-id", this.principal, this.serverWebExchange);
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull();
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientProvider.authorize(
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientRepository.loadAuthorizedClient(
eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(any());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest(
this.authorizedClient, this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenSupportedProviderThenReauthorized() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest(
this.authorizedClient, this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange));
}
@SuppressWarnings("unchecked")
@Test
public void reauthorizeWhenRequestScopeParameterThenMappedToContext() {
OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient));
// Override the mock with the default
this.authorizedClientManager.setContextAttributesMapper(
new DefaultServerOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
this.serverWebExchange = MockServerWebExchange.builder(
MockServerHttpRequest
.get("/")
.queryParam(OAuth2ParameterNames.SCOPE, "read write"))
.build();
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest(
this.authorizedClient, this.principal, this.serverWebExchange);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange));
}
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import org.junit.Test;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link ServerOAuth2AuthorizeRequest}.
*
* @author Joe Grandja
*/
public class ServerOAuth2AuthorizeRequestTests {
private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
private Authentication principal = new TestingAuthenticationToken("principal", "password");
private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
private MockServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
@Test
public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((String) null, this.principal, this.serverWebExchange))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientRegistrationId cannot be empty");
}
@Test
public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.serverWebExchange))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizedClient cannot be null");
}
@Test
public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.serverWebExchange))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("principal cannot be null");
}
@Test
public void constructorWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("serverWebExchange cannot be null");
}
@Test
public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() {
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange);
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId());
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange);
}
@Test
public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() {
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest(
this.authorizedClient, this.principal, this.serverWebExchange);
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId());
assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient);
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange);
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,8 +18,12 @@ package sample.config;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
@ -31,13 +35,31 @@ import org.springframework.web.reactive.function.client.WebClient;
public class WebClientConfig { public class WebClientConfig {
@Bean @Bean
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository, WebClient webClient(ServerOAuth2AuthorizedClientManager authorizedClientManager) {
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
ServerOAuth2AuthorizedClientExchangeFilterFunction oauth = ServerOAuth2AuthorizedClientExchangeFilterFunction oauth =
new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository); new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
oauth.setDefaultOAuth2AuthorizedClient(true); oauth.setDefaultOAuth2AuthorizedClient(true);
return WebClient.builder() return WebClient.builder()
.filter(oauth) .filter(oauth)
.build(); .build();
} }
@Bean
ServerOAuth2AuthorizedClientManager authorizedClientManager(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
DefaultServerOAuth2AuthorizedClientManager authorizedClientManager =
new DefaultServerOAuth2AuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
return authorizedClientManager;
}
} }