diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java index 3fdd744a51..73107c4ce1 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java @@ -20,11 +20,14 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; +import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.ClassUtils; import org.springframework.web.reactive.config.WebFluxConfigurer; import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; @@ -63,7 +66,16 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { @Override public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { - configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository())); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, getAuthorizedClientRepository()); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager)); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 0000000000..002432bd37 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +/** + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AuthorizedClientProvider + */ +public final class AuthorizationCodeReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + + /** + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns an empty {@code Mono} if authorization is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && + context.getAuthorizedClient() == null) { + // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectWebFilter which initiates authorization + return Mono.error(() -> new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId())); + } + return Mono.empty(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 0000000000..4b31b4269a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; + +/** + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see WebClientReactiveClientCredentialsTokenResponseClient + */ +public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = + new WebClientReactiveClientCredentialsTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns an empty {@code Mono} if authorization (or re-authorization) is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization (or re-authorization) is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return Mono.empty(); + } + + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no need for re-authorization + return Mono.empty(); + } + + // As per spec, in section 4.4.3 Access Token Response + // https://tools.ietf.org/html/rfc6749#section-4.4.3 + // A refresh token SHOULD NOT be included. + // + // Therefore, renewing an expired access token (re-authorization) + // is the same as acquiring a new access token (authorization). + + return Mono.just(new OAuth2ClientCredentialsGrantRequest(clientRegistration)) + .flatMap(this.accessTokenResponseClient::getTokenResponse) + .map(tokenResponse -> new OAuth2AuthorizedClient( + clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken())); + } + + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + */ + public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 0000000000..1264d792c5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} that simply delegates + * to it's internal {@code List} of {@link ReactiveOAuth2AuthorizedClientProvider}(s). + *

+ * Each provider is given a chance to + * {@link ReactiveOAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} + * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context + * with the first available {@link OAuth2AuthorizedClient} being returned. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AuthorizedClientProvider + */ +public final class DelegatingReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + private final List authorizedClientProviders; + + /** + * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a list of {@link ReactiveOAuth2AuthorizedClientProvider}(s) + */ + public DelegatingReactiveOAuth2AuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider... authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(Arrays.asList(authorizedClientProviders)); + } + + /** + * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s) + */ + public DelegatingReactiveOAuth2AuthorizedClientProvider(List authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(new ArrayList<>(authorizedClientProviders)); + } + + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + return Flux.fromIterable(this.authorizedClientProviders) + .concatMap(authorizedClientProvider -> authorizedClientProvider.authorize(context)) + .next(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 0000000000..0775846f0c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import reactor.core.publisher.Mono; + +/** + * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. + * Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizationContext + * @see Section 1.3 Authorization Grant + */ +public interface ReactiveOAuth2AuthorizedClientProvider { + + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. + * Implementations must return an empty {@code Mono} if authorization is not supported for the specified client, + * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client + */ + Mono authorize(OAuth2AuthorizationContext context); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java new file mode 100644 index 0000000000..22b438d42e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +/** + * A builder that builds a {@link DelegatingReactiveOAuth2AuthorizedClientProvider} composed of + * one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s) that implement specific authorization grants. + * The supported authorization grants are {@link #authorizationCode() authorization_code}, + * {@link #refreshToken() refresh_token} and {@link #clientCredentials() client_credentials}. + * In addition to the standard authorization grants, an implementation of an extension grant + * may be supplied via {@link #provider(ReactiveOAuth2AuthorizedClientProvider)}. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see AuthorizationCodeReactiveOAuth2AuthorizedClientProvider + * @see RefreshTokenReactiveOAuth2AuthorizedClientProvider + * @see ClientCredentialsReactiveOAuth2AuthorizedClientProvider + * @see DelegatingReactiveOAuth2AuthorizedClientProvider + */ +public final class ReactiveOAuth2AuthorizedClientProviderBuilder { + private final Map, Builder> builders = new LinkedHashMap<>(); + + private ReactiveOAuth2AuthorizedClientProviderBuilder() { + } + + /** + * Returns a new {@link ReactiveOAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s). + * + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public static ReactiveOAuth2AuthorizedClientProviderBuilder builder() { + return new ReactiveOAuth2AuthorizedClientProviderBuilder(); + } + + /** + * Configures a {@link ReactiveOAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingReactiveOAuth2AuthorizedClientProvider}. + * This may be used for implementations of extension authorization grants. + * + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder provider(ReactiveOAuth2AuthorizedClientProvider provider) { + Assert.notNull(provider, "provider cannot be null"); + this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code authorization_code} grant. + * + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder authorizationCode() { + this.builders.computeIfAbsent(AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code authorization_code} grant. + */ + public class AuthorizationCodeGrantBuilder implements Builder { + + private AuthorizationCodeGrantBuilder() { + } + + /** + * Builds an instance of {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}. + * + * @return the {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider} + */ + @Override + public ReactiveOAuth2AuthorizedClientProvider build() { + return new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider(); + } + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken() { + this.builders.computeIfAbsent(RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { + RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( + RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + builderConsumer.accept(builder); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code refresh_token} grant. + */ + public class RefreshTokenGrantBuilder implements Builder { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private RefreshTokenGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. + * + * @return the {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider} + */ + @Override + public ReactiveOAuth2AuthorizedClientProvider build() { + RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials() { + this.builders.computeIfAbsent(ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { + ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( + ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + builderConsumer.accept(builder); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code client_credentials} grant. + */ + public class ClientCredentialsGrantBuilder implements Builder { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private ClientCredentialsGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}. + * + * @return the {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} + */ + @Override + public ReactiveOAuth2AuthorizedClientProvider build() { + ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Builds an instance of {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + * composed of one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s). + * + * @return the {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + */ + public ReactiveOAuth2AuthorizedClientProvider build() { + List authorizedClientProviders = + this.builders.values().stream() + .map(Builder::build) + .collect(Collectors.toList()); + return new DelegatingReactiveOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + interface Builder { + ReactiveOAuth2AuthorizedClientProvider build(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 0000000000..000b03c98a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see WebClientReactiveRefreshTokenTokenResponseClient + */ +public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = + new WebClientReactiveRefreshTokenTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + + /** + * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns an empty {@code Mono} if re-authorization is not supported, + * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} + * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) + * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  2. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if re-authorization is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient == null || + authorizedClient.getRefreshToken() == null || + !hasTokenExpired(authorizedClient.getAccessToken())) { + return Mono.empty(); + } + + Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + Set scopes = Collections.emptySet(); + if (requestScope != null) { + Assert.isInstanceOf(String[].class, requestScope, + "The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); + } + ClientRegistration clientRegistration = context.getClientRegistration(); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); + + return Mono.just(refreshTokenGrantRequest) + .flatMap(this.accessTokenResponseClient::getTokenResponse) + .map(tokenResponse -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); + } + + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant + */ + public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java new file mode 100644 index 0000000000..4d7938242c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; + +/** + * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * This implementation uses {@link WebClient} when requesting + * an access token credential at the Authorization Server's Token Endpoint. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveOAuth2AccessTokenResponseClient + * @see OAuth2RefreshTokenGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 6 Refreshing an Access Token + */ +public final class WebClientReactiveRefreshTokenTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + private WebClient webClient = WebClient.builder().build(); + + @Override + public Mono getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); + return Mono.defer(() -> { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); + return this.webClient.post() + .uri(clientRegistration.getProviderDetails().getTokenUri()) + .headers(tokenRequestHeaders(clientRegistration)) + .body(tokenRequestBody(refreshTokenGrantRequest)) + .exchange() + .flatMap(response -> { + if (!response.statusCode().is2xxSuccessful()) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + "HTTP Status Code " + response.rawStatusCode(), null); + return response + .bodyToMono(DataBuffer.class) + .map(DataBufferUtils::release) + .then(Mono.error(new OAuth2AuthorizationException(oauth2Error))); + } + return response.body(oauth2AccessTokenResponse()); + }) + .map(tokenResponse -> tokenResponse(refreshTokenGrantRequest, tokenResponse)); + }); + } + + private static Consumer tokenRequestHeaders(ClientRegistration clientRegistration) { + return headers -> { + headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + } + }; + } + + private static BodyInserters.FormInserter tokenRequestBody(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); + BodyInserters.FormInserter body = BodyInserters.fromFormData( + OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); + body.with(OAuth2ParameterNames.REFRESH_TOKEN, + refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { + body.with(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + return body; + } + + private static OAuth2AccessTokenResponse tokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest, + OAuth2AccessTokenResponse accessTokenResponse) { + if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) && + accessTokenResponse.getRefreshToken() != null) { + return accessTokenResponse; + } + + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(accessTokenResponse); + if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { + // As per spec, in Section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Token Request + tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); + } + if (accessTokenResponse.getRefreshToken() == null) { + // Reuse existing refresh token + tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + } + return tokenResponseBuilder.build(); + } + + /** + * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response. + * + * @param webClient the {@link WebClient} used when requesting the Access Token Response + */ + public void setWebClient(WebClient webClient) { + Assert.notNull(webClient, "webClient cannot be null"); + this.webClient = webClient; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index c8da1dca42..9719fe98b5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,23 +16,26 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.Assert; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; @@ -40,21 +43,17 @@ import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; -import java.net.URI; -import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Map; import java.util.Optional; import java.util.function.Consumer; -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; - /** * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * token as a Bearer Token. * * @author Rob Winch + * @author Joe Grandja * @since 5.1 */ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { @@ -76,21 +75,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); - private Clock clock = Clock.systemUTC(); + private ServerOAuth2AuthorizedClientManager authorizedClientManager; + private boolean defaultAuthorizedClientManager; + + private boolean defaultOAuth2AuthorizedClient; + + private String defaultClientRegistrationId; + + @Deprecated private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + @Deprecated + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; - private final OAuth2AuthorizedClientResolver authorizedClientResolver; - public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository)); + /** + * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @since 5.2 + * @param authorizedClientManager the {@link ServerOAuth2AuthorizedClientManager} which manages the authorized client(s) + */ + public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; } - ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) { - this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientResolver = authorizedClientResolver; + /** + * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + this.defaultAuthorizedClientManager = true; + } + + private static ServerOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } /** @@ -99,7 +136,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * *
 	 * WebClient webClient = WebClient.builder()
-	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
+	 *    .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
 	 *    .build();
 	 * Mono response = webClient
 	 *    .get()
@@ -114,8 +151,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	 * are true:
 	 *
 	 * 
    - *
  • The ReactiveOAuth2AuthorizedClientService on the - * {@link ServerOAuth2AuthorizedClientExchangeFilterFunction} is not null
  • *
  • A refresh token is present on the OAuth2AuthorizedClient
  • *
  • The access token will be expired in * {@link #setAccessTokenExpiresSkew(Duration)}
  • @@ -136,12 +171,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for + * Modifies the {@link ClientRequest#attributes()} to include the {@link ServerWebExchange} to be used for * providing the Bearer Token. Example usage: * *
     	 * WebClient webClient = WebClient.builder()
    -	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
    +	 *    .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
     	 *    .build();
     	 * Mono response = webClient
     	 *    .get()
    @@ -190,7 +225,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
     	 *                                      Default is false.
     	 */
     	public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
    -		this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient);
    +		this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
     	}
     
     	/**
    @@ -199,124 +234,127 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
     	 * @param clientRegistrationId the id to use
     	 */
     	public void setDefaultClientRegistrationId(String clientRegistrationId) {
    -		this.authorizedClientResolver.setDefaultClientRegistrationId(clientRegistrationId);
    +		this.defaultClientRegistrationId = clientRegistrationId;
     	}
     
     	/**
    -	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
    -	 * client_credentials grant.
    +	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
    +	 *
    +	 * @deprecated Use {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)} instead.
    +	 * 				Create an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} configured with a
    +	 * 				{@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient) WebClientReactiveClientCredentialsTokenResponseClient}
    +	 * 				(or a custom one) and than supply it to {@link DefaultServerOAuth2AuthorizedClientManager#setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}.
    +	 *
     	 * @param clientCredentialsTokenResponseClient the client to use
     	 */
    +	@Deprecated
     	public void setClientCredentialsTokenResponseClient(
     			ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) {
    -		this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
    +		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
    +		Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
    +				"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
    +		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
    +		updateDefaultAuthorizedClientManager();
    +	}
    +
    +	private void updateDefaultAuthorizedClientManager() {
    +		ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
    +				ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
    +						.authorizationCode()
    +						.refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
    +						.clientCredentials(this::updateClientCredentialsProvider)
    +						.build();
    +		((DefaultServerOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
    +	}
    +
    +	private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
    +		if (this.clientCredentialsTokenResponseClient != null) {
    +			builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient);
    +		}
    +		builder.clockSkew(this.accessTokenExpiresSkew);
     	}
     
     	/**
     	 * An access token will be considered expired by comparing its expiration to now +
     	 * this skewed Duration. The default is 1 minute.
    +	 *
    +	 * @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link ReactiveOAuth2AuthorizedClientProvider} implementation,
    +	 * 				e.g. {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsReactiveOAuth2AuthorizedClientProvider} or
    +	 * 				{@link RefreshTokenReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
    +	 *
     	 * @param accessTokenExpiresSkew the Duration to use.
     	 */
    +	@Deprecated
     	public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
     		Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
    +		Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " +
    +				"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
     		this.accessTokenExpiresSkew = accessTokenExpiresSkew;
    +		updateDefaultAuthorizedClientManager();
     	}
     
     	@Override
     	public Mono filter(ClientRequest request, ExchangeFunction next) {
    -		return authorizedClient(request, next)
    +		return authorizedClient(request)
     				.map(authorizedClient -> bearer(request, authorizedClient))
     				.flatMap(next::exchange)
     				.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
     	}
     
    -	private Mono authorizedClient(ClientRequest request, ExchangeFunction next) {
    +	private Mono authorizedClient(ClientRequest request) {
     		OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
     		return Mono.justOrEmpty(authorizedClientFromAttrs)
    -				.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request)))
    -				.flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient));
    +				.switchIfEmpty(Mono.defer(() ->
    +						authorizeRequest(request).flatMap(this.authorizedClientManager::authorize)))
    +				.flatMap(authorizedClient ->
    +						reauthorizeRequest(request, authorizedClient).flatMap(this.authorizedClientManager::authorize));
     	}
     
    -	private Mono loadAuthorizedClient(ClientRequest request) {
    -		return createRequest(request)
    -			.flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r));
    +	private Mono authorizeRequest(ClientRequest request) {
    +		Mono authentication = currentAuthentication();
    +
    +		Mono clientRegistrationId = Mono.justOrEmpty(clientRegistrationId(request))
    +				.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
    +				.switchIfEmpty(clientRegistrationId(authentication));
    +
    +		Mono> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
    +				.switchIfEmpty(currentServerWebExchange())
    +				.map(Optional::of)
    +				.defaultIfEmpty(Optional.empty());
    +
    +		return Mono.zip(clientRegistrationId, authentication, serverWebExchange)
    +				.map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
     	}
     
    -	private Mono createRequest(ClientRequest request) {
    -		String clientRegistrationId = clientRegistrationId(request);
    -		Authentication authentication = null;
    -		ServerWebExchange exchange = serverWebExchange(request);
    -		return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange);
    +	private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
    +		Mono authentication = currentAuthentication();
    +
    +		Mono> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
    +				.switchIfEmpty(currentServerWebExchange())
    +				.map(Optional::of)
    +				.defaultIfEmpty(Optional.empty());
    +
    +		return Mono.zip(authentication, serverWebExchange)
    +				.map(t2 -> new ServerOAuth2AuthorizeRequest(authorizedClient, t2.getT1(), t2.getT2().orElse(null)));
     	}
     
    -	private Mono refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
    -		ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
    -		if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
    -			return createRequest(request)
    -					.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
    -		} else if (shouldRefreshToken(authorizedClient)) {
    -			return createRequest(request)
    -				.flatMap(r -> authorizeWithRefreshToken(next, authorizedClient, r));
    -		}
    -		return Mono.just(authorizedClient);
    +	private Mono currentAuthentication() {
    +		return ReactiveSecurityContextHolder.getContext()
    +				.map(SecurityContext::getAuthentication)
    +				.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
     	}
     
    -	private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
    -		return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
    +	private Mono clientRegistrationId(Mono authentication) {
    +		return authentication
    +				.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
    +				.cast(OAuth2AuthenticationToken.class)
    +				.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
     	}
     
    -	private Mono authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
    -		Authentication authentication = request.getAuthentication();
    -		ServerWebExchange exchange = request.getExchange();
    -
    -		return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
    -				flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
    -						.thenReturn(result));
    -	}
    -
    -	private Mono authorizeWithRefreshToken(ExchangeFunction next,
    -																	OAuth2AuthorizedClient authorizedClient,
    -																	OAuth2AuthorizedClientResolver.Request r) {
    -		ServerWebExchange exchange = r.getExchange();
    -		Authentication authentication = r.getAuthentication();
    -		ClientRegistration clientRegistration = authorizedClient
    -				.getClientRegistration();
    -		String tokenUri = clientRegistration
    -				.getProviderDetails().getTokenUri();
    -		ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
    -				.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
    -				.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
    -				.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
    -				.build();
    -		return next.exchange(refreshRequest)
    -				.flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
    -				.map(accessTokenResponse -> {
    -					OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken())
    -							.orElse(authorizedClient.getRefreshToken());
    -					return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
    -				})
    -				.flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
    -						.thenReturn(result));
    -	}
    -
    -	private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
    -		if (this.authorizedClientRepository == null) {
    -			return false;
    -		}
    -		OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
    -		if (refreshToken == null) {
    -			return false;
    -		}
    -		return hasTokenExpired(authorizedClient);
    -	}
    -
    -	private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
    -		Instant now = this.clock.instant();
    -		Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
    -		if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
    -			return true;
    -		}
    -		return false;
    +	private Mono currentServerWebExchange() {
    +		return Mono.subscriberContext()
    +				.filter(c -> c.hasKey(ServerWebExchange.class))
    +				.map(c -> c.get(ServerWebExchange.class));
     	}
     
     	private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
    @@ -324,10 +362,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
     					.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
     					.build();
     	}
    -
    -	private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) {
    -		return BodyInserters
    -				.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
    -				.with("refresh_token", refreshToken);
    -	}
     }
    diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
    index 1255fd0304..0784ec4000 100644
    --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
    +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
    @@ -1,5 +1,5 @@
     /*
    - * Copyright 2002-2018 the original author or authors.
    + * Copyright 2002-2019 the original author or authors.
      *
      * Licensed under the Apache License, Version 2.0 (the "License");
      * you may not use this file except in compliance with the License.
    @@ -18,9 +18,20 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an
     
     import org.springframework.core.MethodParameter;
     import org.springframework.core.annotation.AnnotatedElementUtils;
    +import org.springframework.security.authentication.AnonymousAuthenticationToken;
    +import org.springframework.security.core.Authentication;
    +import org.springframework.security.core.authority.AuthorityUtils;
    +import org.springframework.security.core.context.ReactiveSecurityContextHolder;
    +import org.springframework.security.core.context.SecurityContext;
     import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
    +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
    +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
     import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
    +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
     import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
    +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager;
    +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest;
    +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager;
     import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
     import org.springframework.util.Assert;
     import org.springframework.util.StringUtils;
    @@ -46,22 +57,53 @@ import reactor.core.publisher.Mono;
      * 
    * * @author Rob Winch + * @author Joe Grandja * @since 5.1 * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { - - private final OAuth2AuthorizedClientResolver authorizedClientResolver; + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); + private ServerOAuth2AuthorizedClientManager authorizedClientManager; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * - * @param authorizedClientRepository the authorized client repository + * @since 5.2 + * @param authorizedClientManager the {@link ServerOAuth2AuthorizedClientManager} which manages the authorized client(s) */ - public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; + } + + /** + * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); - this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true); + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + } + + private static ServerOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } @Override @@ -70,8 +112,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth } @Override - public Mono resolveArgument( - MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) { + public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) { return Mono.defer(() -> { RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); @@ -79,8 +120,41 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ? authorizedClientAnnotation.registrationId() : null; - return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange) - .flatMap(this.authorizedClientResolver::loadAuthorizedClient); + return authorizeRequest(clientRegistrationId, exchange) + .flatMap(this.authorizedClientManager::authorize); }); } + + private Mono authorizeRequest(String registrationId, ServerWebExchange exchange) { + Mono defaultedAuthentication = currentAuthentication(); + + Mono defaultedRegistrationId = Mono.justOrEmpty(registrationId) + .switchIfEmpty(clientRegistrationId(defaultedAuthentication)) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one"))); + + Mono defaultedExchange = Mono.justOrEmpty(exchange) + .switchIfEmpty(currentServerWebExchange()); + + return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) + .map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3())); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private Mono clientRegistrationId(Mono authentication) { + return authentication + .filter(t -> t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java new file mode 100644 index 0000000000..54d61c7a11 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +/** + * The default implementation of a {@link ServerOAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + * @since 5.2 + * @see ServerOAuth2AuthorizedClientManager + * @see ReactiveOAuth2AuthorizedClientProvider + */ +public final class DefaultServerOAuth2AuthorizedClientManager implements ServerOAuth2AuthorizedClientManager { + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); + private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + + /** + * Constructs a {@code DefaultServerOAuth2AuthorizedClientManager} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public DefaultServerOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + @Override + public Mono authorize(ServerOAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + Authentication principal = authorizeRequest.getPrincipal(); + ServerWebExchange serverWebExchange = authorizeRequest.getServerWebExchange(); + + return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) + .switchIfEmpty(Mono.defer(() -> + this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) + .flatMap(authorizedClient -> { + // Re-authorize + OAuth2AuthorizationContext reauthorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .build(); + return Mono.just(reauthorizationContext) + .flatMap(this.authorizedClientProvider::authorize) + .doOnNext(reauthorizedClient -> + this.authorizedClientRepository.saveAuthorizedClient( + reauthorizedClient, principal, serverWebExchange)) + // Return the `authorizedClient` if `reauthorizedClient` is null, e.g. re-authorization is not supported + .defaultIfEmpty(authorizedClient); + }) + .switchIfEmpty(Mono.defer(() -> + // Authorize + this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) + .map(clientRegistration -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .build()) + .flatMap(this.authorizedClientProvider::authorize) + .doOnNext(authorizedClient -> + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, principal, serverWebExchange)) + )); + } + + /** + * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code Function} used for mapping attribute(s) from the {@link ServerOAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(Function> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } + + /** + * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + */ + public static class DefaultContextAttributesMapper implements Function> { + + @Override + public Map apply(ServerOAuth2AuthorizeRequest authorizeRequest) { + Map contextAttributes = Collections.emptyMap(); + String scope = authorizeRequest.getServerWebExchange().getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return contextAttributes; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java new file mode 100644 index 0000000000..6aee1feb4f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Represents a request the {@link ServerOAuth2AuthorizedClientManager} uses to + * {@link ServerOAuth2AuthorizedClientManager#authorize(ServerOAuth2AuthorizeRequest) authorize} (or re-authorize) + * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. + * + * @author Joe Grandja + * @since 5.2 + * @see ServerOAuth2AuthorizedClientManager + */ +public class ServerOAuth2AuthorizeRequest { + private final String clientRegistrationId; + private final OAuth2AuthorizedClient authorizedClient; + private final Authentication principal; + private final ServerWebExchange serverWebExchange; + + /** + * Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters. + * + * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} + * @param principal the {@code Principal} (to be) associated to the authorized client + * @param serverWebExchange the {@code ServerWebExchange} + */ + public ServerOAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, + ServerWebExchange serverWebExchange) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); + this.clientRegistrationId = clientRegistrationId; + this.authorizedClient = null; + this.principal = principal; + this.serverWebExchange = serverWebExchange; + } + + /** + * Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @param principal the {@code Principal} (to be) associated to the authorized client + * @param serverWebExchange the {@code ServerWebExchange} + */ + public ServerOAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange serverWebExchange) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); + this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); + this.authorizedClient = authorizedClient; + this.principal = principal; + this.serverWebExchange = serverWebExchange; + } + + /** + * Returns the identifier for the {@link ClientRegistration client registration}. + * + * @return the identifier for the client registration + */ + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the {@code Principal} (to be) associated to the authorized client. + * + * @return the {@code Principal} (to be) associated to the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the {@link ServerWebExchange}. + * + * @return the {@link ServerWebExchange} + */ + public ServerWebExchange getServerWebExchange() { + return this.serverWebExchange; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java new file mode 100644 index 0000000000..dd24f9832e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import reactor.core.publisher.Mono; + +/** + * Implementations of this interface are responsible for the overall management + * of {@link OAuth2AuthorizedClient Authorized Client(s)}. + * + *

    + * The primary responsibilities include: + *

      + *
    1. Authorizing (or re-authorizing) an OAuth 2.0 Client + * by leveraging a {@link ReactiveOAuth2AuthorizedClientProvider}(s).
    2. + *
    3. Managing the persistence of an {@link OAuth2AuthorizedClient} between requests, + * typically using an {@link ServerOAuth2AuthorizedClientRepository}.
    4. + *
    + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see ReactiveOAuth2AuthorizedClientProvider + * @see ServerOAuth2AuthorizedClientRepository + */ +public interface ServerOAuth2AuthorizedClientManager { + + /** + * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} + * identified by the provided {@link ServerOAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. + * Implementations must return an empty {@code Mono} if authorization is not supported for the specified client, + * e.g. the associated {@link ReactiveOAuth2AuthorizedClientProvider}(s) does not support + * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + *

    + * In the case of re-authorization, implementations must return the provided {@link ServerOAuth2AuthorizeRequest#getAuthorizedClient() authorized client} + * if re-authorization is not supported for the client OR is not required, + * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + * @param authorizeRequest the authorize request + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client + */ + Mono authorize(ServerOAuth2AuthorizeRequest authorizeRequest); + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 0000000000..97bf724011 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests { + private AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.authorizedClientProvider = new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { + ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 0000000000..f7a4fe1295 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { + private ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + + @Before + public void setup() { + this.authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientCredentials().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew cannot be null"); + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), accessToken); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 0000000000..45ddf6e528 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import reactor.core.publisher.Mono; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingReactiveOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class DelegatingReactiveOAuth2AuthorizedClientProviderTests { + + @Test + public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(new ReactiveOAuth2AuthorizedClientProvider[0])) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(Collections.emptyList())) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( + mock(ReactiveOAuth2AuthorizedClientProvider.class)); + assertThatThrownBy(() -> delegate.authorize(null).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { + Authentication principal = new TestingAuthenticationToken("principal", "password"); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes()); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider3 = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider3.authorize(any())).thenReturn(Mono.just(authorizedClient)); + + DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( + authorizedClientProvider1, authorizedClientProvider2, authorizedClientProvider3); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context).block(); + assertThat(reauthorizedClient).isSameAs(authorizedClient); + } + + @Test + public void authorizeWhenProviderCantAuthorizeThenReturnNull() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(new TestingAuthenticationToken("principal", "password")) + .build(); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty()); + + DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( + authorizedClientProvider1, authorizedClientProvider2); + assertThat(delegate.authorize(context).block()).isNull(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index 89236d4c4f..efa307459c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -44,28 +44,28 @@ public class OAuth2AuthorizationContextTests { } @Test - public void forClientWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration((ClientRegistration) null).build()) + public void withClientRegistrationWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientRegistration cannot be null"); } @Test - public void forClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient((OAuth2AuthorizedClient) null).build()) + public void withAuthorizedClientWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizedClient cannot be null"); } @Test - public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + public void withClientRegistrationWhenPrincipalIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test - public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { + public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .attribute("attribute1", "value1") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java new file mode 100644 index 0000000000..fd77cc4f1d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link ReactiveOAuth2AuthorizedClientProviderBuilder}. + * + * @author Joe Grandja + */ +public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private Authentication principal; + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void providerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ReactiveOAuth2AuthorizedClientProviderBuilder.builder().provider(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block()) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void buildWhenRefreshTokenProviderThenProviderReauthorizes() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .refreshToken() + .build(); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistrationBuilder.build(), + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); + + assertThat(reauthorizedClient).isNotNull(); + + assertThat(this.server.getRequestCount()).isEqualTo(1); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=refresh_token"); + } + + @Test + public void buildWhenClientCredentialsProviderThenProviderAuthorizes() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .clientCredentials() + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); + + assertThat(authorizedClient).isNotNull(); + + assertThat(this.server.getRequestCount()).isEqualTo(1); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=client_credentials"); + } + + @Test + public void buildWhenAllProvidersThenProvidersAuthorize() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + + // authorization_code + OAuth2AuthorizationContext authorizationCodeContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block()) + .isInstanceOf(ClientAuthorizationRequiredException.class); + + + // refresh_token + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistrationBuilder.build(), + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext refreshTokenContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext).block(); + + assertThat(reauthorizedClient).isNotNull(); + + assertThat(this.server.getRequestCount()).isEqualTo(1); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=refresh_token"); + + + // client_credentials + OAuth2AuthorizationContext clientCredentialsContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) + .principal(this.principal) + .build(); + authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext).block(); + + assertThat(authorizedClient).isNotNull(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + recordedRequest = this.server.takeRequest(); + formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=client_credentials"); + } + + @Test + public void buildWhenCustomProviderThenProviderCalled() { + ReactiveOAuth2AuthorizedClientProvider customProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(customProvider.authorize(any())).thenReturn(Mono.empty()); + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .provider(customProvider) + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + authorizedClientProvider.authorize(authorizationContext).block(); + + verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + private OAuth2AccessToken expiredAccessToken() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 0000000000..97bc688a5d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,188 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { + private RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + + @Before + public void setup() { + this.authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew cannot be null"); + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + } + + @Test + public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + String[] requestScope = new String[] { "read", "write" }; + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) + .build(); + + this.authorizedClientProvider.authorize(authorizationContext).block(); + + ArgumentCaptor refreshTokenGrantRequestArgCaptor = + ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope))); + } + + @Test + public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { + String invalidRequestScope = "read write"; + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) + .build(); + + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java new file mode 100644 index 0000000000..8449546b9a --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import java.time.Instant; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}. + * + * @author Joe Grandja + */ +public class WebClientReactiveRefreshTokenTokenResponseClientTests { + private WebClientReactiveRefreshTokenTokenResponseClient tokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setWebClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setWebClient(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); + + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=refresh_token"); + assertThat(formParameters).contains("refresh_token=refresh-token"); + + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); + } + + @Test + public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .build(); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + clientRegistration, this.accessToken, this.refreshToken); + + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-id"); + assertThat(formParameters).contains("client_secret=client-secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .hasMessageContaining("Token type must be \"Bearer\""); + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("scope=read"); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("HTTP Status Code 400"); + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("HTTP Status Code 500"); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 7b8ba1889f..917509bac4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -36,15 +36,23 @@ import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request; +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -68,12 +76,10 @@ import java.util.Map; import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; @@ -91,10 +97,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { private ReactiveClientRegistrationRepository clientRegistrationRepository; @Mock - private OAuth2AuthorizedClientResolver authorizedClientResolver; + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; @Mock - private ServerWebExchange serverWebExchange; + private ReactiveOAuth2AccessTokenResponseClient refreshTokenTokenResponseClient; + + private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); @Captor private ArgumentCaptor authorizedClientCaptor; @@ -113,7 +121,45 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Before public void setup() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); + } + + @Test + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new WebClientReactiveClientCredentialsTokenResponseClient())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + } + + @Test + public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); } @Test @@ -134,7 +180,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); } @@ -148,7 +196,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); @@ -156,47 +206,35 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(360) + .build(); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); - String clientRegistrationId = registration.getClientId(); - - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.authorizedClientResolver); - - OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "new-token", - Instant.now(), - Instant.now().plus(Duration.ofDays(1))); - OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration, - "principalName", newAccessToken, null); - Request r = new Request(clientRegistrationId, authentication, null); - when(this.authorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient)); - when(this.authorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r)); - - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", accessToken, null); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscriberContext(serverWebExchange()) .block(); + verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); - verify(this.authorizedClientResolver).clientCredentials(any(), any(), any()); - verify(this.authorizedClientResolver).createDefaultedRequest(any(), any(), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -212,8 +250,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.authorizedClientResolver); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", this.accessToken, null); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -222,10 +258,10 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscriberContext(serverWebExchange()) .block(); - verify(this.authorizedClientResolver, never()).clientCredentials(any(), any(), any()); - verify(this.authorizedClientResolver, never()).createDefaultedRequest(any(), any(), any()); + verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -238,24 +274,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response)); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); @@ -263,8 +298,10 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscriberContext(serverWebExchange()) .block(); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); @@ -272,84 +309,26 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); - } - - @Test - public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken() { - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) -// .refreshToken(xxx) // No refreshToken in response - .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); - Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); - Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .build(); - - TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); - this.function.filter(request, this.exchange) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .block(); - - verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any()); - - OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); - assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); - assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(authorizedClient.getRefreshToken()); - - List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); - - ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); - Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, @@ -363,24 +342,20 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .build(); this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) .block(); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test @@ -391,7 +366,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -412,7 +389,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -430,12 +409,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(clientRegistrationId(this.registration.getRegistrationId())) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -454,11 +434,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .build(); - this.function.filter(request, this.exchange).block(); + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -478,7 +459,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .build(); @@ -488,6 +468,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.function .filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .subscriberContext(serverWebExchange()) .block(); List requests = this.exchange.getRequests(); @@ -526,7 +507,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(clientRegistrationId(this.registration.getRegistrationId())) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index b06d4fbd86..aa99e41155 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,24 +22,32 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.core.MethodParameter; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.util.ReflectionUtils; +import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import reactor.util.context.Context; import java.lang.reflect.Method; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -55,24 +63,50 @@ public class OAuth2AuthorizedClientArgumentResolverTests { private ReactiveClientRegistrationRepository clientRegistrationRepository; @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AuthorizedClientArgumentResolver argumentResolver; - private OAuth2AuthorizedClient authorizedClient; + private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + + private OAuth2AuthorizedClientArgumentResolver argumentResolver; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; private Authentication authentication = new TestingAuthenticationToken("test", "this"); @Before public void setUp() { - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository); - this.authorizedClient = mock(OAuth2AuthorizedClient.class); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.authentication.getName(), TestOAuth2AccessTokens.noScopes()); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient)); } @Test - public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) .isInstanceOf(IllegalArgumentException.class); } + @Test + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() { MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); @@ -101,8 +135,6 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( - TestClientRegistrations.clientRegistration().build())); this.authentication = mock(OAuth2AuthenticationToken.class); when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); @@ -112,24 +144,19 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { this.authentication = null; - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( - TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( - TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( - TestClientRegistrations.clientRegistration().build())); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty()); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> resolveArgument(methodParameter)) @@ -139,9 +166,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests { private Object resolveArgument(MethodParameter methodParameter) { return this.argumentResolver.resolveArgument(methodParameter, null, null) .subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .subscriberContext(serverWebExchange()) .block(); } + private Context serverWebExchange() { + return Context.of(ServerWebExchange.class, this.serverWebExchange); + } + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { Method method = ReflectionUtils.findMethod( TestController.class, methodName, paramTypes); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java new file mode 100644 index 0000000000..9729120cd5 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java @@ -0,0 +1,295 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link DefaultServerOAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + */ +public class DefaultServerOAuth2AuthorizedClientManagerTests { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private DefaultServerOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private MockServerWebExchange serverWebExchange; + private ArgumentCaptor authorizationContextCaptor; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + when(this.clientRegistrationRepository.findByRegistrationId( + anyString())).thenReturn(Mono.empty()); + this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class); + when(this.authorizedClientRepository.loadAuthorizedClient( + anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty()); + this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); + this.contextAttributesMapper = mock(Function.class); + when(this.contextAttributesMapper.apply(any())).thenReturn(Collections.emptyMap()); + this.authorizedClientManager = new DefaultServerOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultServerOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultServerOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizeRequest cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + "invalid-registration-id", this.principal, this.serverWebExchange); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isNull(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(any()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + // Override the mock with the default + this.authorizedClientManager.setContextAttributesMapper( + new DefaultServerOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); + + this.serverWebExchange = MockServerWebExchange.builder( + MockServerHttpRequest + .get("/") + .queryParam(OAuth2ParameterNames.SCOPE, "read write")) + .build(); + + ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java new file mode 100644 index 0000000000..d4ab401972 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.junit.Test; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link ServerOAuth2AuthorizeRequest}. + * + * @author Joe Grandja + */ +public class ServerOAuth2AuthorizeRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + private Authentication principal = new TestingAuthenticationToken("principal", "password"); + private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + private MockServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + + @Test + public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((String) null, this.principal, this.serverWebExchange)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + } + + @Test + public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.serverWebExchange)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.serverWebExchange)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void constructorWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("serverWebExchange cannot be null"); + } + + @Test + public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange); + } + + @Test + public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { + ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.serverWebExchange); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange); + } +} diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java index 7a2130f2c9..7fdba365b7 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,12 @@ package sample.config; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction; +import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.web.reactive.function.client.WebClient; @@ -31,13 +35,31 @@ import org.springframework.web.reactive.function.client.WebClient; public class WebClientConfig { @Bean - WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository, - ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + WebClient webClient(ServerOAuth2AuthorizedClientManager authorizedClientManager) { ServerOAuth2AuthorizedClientExchangeFilterFunction oauth = - new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository); + new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); oauth.setDefaultOAuth2AuthorizedClient(true); return WebClient.builder() .filter(oauth) .build(); } + + @Bean + ServerOAuth2AuthorizedClientManager authorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = + ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultServerOAuth2AuthorizedClientManager authorizedClientManager = + new DefaultServerOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; + } }