diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 2bfdce2f5f..e218571456 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -15,6 +15,17 @@ */ package org.springframework.security.oauth2.client.registration; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -22,14 +33,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.io.Serializable; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; +import static java.util.Collections.EMPTY_MAP; /** * A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0 Provider. @@ -263,6 +267,17 @@ public final class ClientRegistration implements Serializable { return new Builder(registrationId); } + /** + * Returns a new {@link Builder}, initialized with the provided {@link ClientRegistration}. + * + * @param clientRegistration the {@link ClientRegistration} to copy from + * @return the {@link Builder} + */ + public static Builder withClientRegistration(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + return new Builder(clientRegistration); + } + /** * A builder for {@link ClientRegistration}. */ @@ -288,6 +303,27 @@ public final class ClientRegistration implements Serializable { this.registrationId = registrationId; } + private Builder(ClientRegistration clientRegistration) { + this.registrationId = clientRegistration.registrationId; + this.clientId = clientRegistration.clientId; + this.clientSecret = clientRegistration.clientSecret; + this.clientAuthenticationMethod = clientRegistration.clientAuthenticationMethod; + this.authorizationGrantType = clientRegistration.authorizationGrantType; + this.redirectUriTemplate = clientRegistration.redirectUriTemplate; + this.scopes = clientRegistration.scopes == null ? null : new HashSet<>(clientRegistration.scopes); + this.authorizationUri = clientRegistration.providerDetails.authorizationUri; + this.tokenUri = clientRegistration.providerDetails.tokenUri; + this.userInfoUri = clientRegistration.providerDetails.userInfoEndpoint.uri; + this.userInfoAuthenticationMethod = clientRegistration.providerDetails.userInfoEndpoint.authenticationMethod; + this.userNameAttributeName = clientRegistration.providerDetails.userInfoEndpoint.userNameAttributeName; + this.jwkSetUri = clientRegistration.providerDetails.jwkSetUri; + Map configurationMetadata = clientRegistration.providerDetails.configurationMetadata; + if (configurationMetadata != EMPTY_MAP) { + this.configurationMetadata = new HashMap<>(configurationMetadata); + } + this.clientName = clientRegistration.clientName; + } + /** * Sets the registration id. * diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java index e770376b9c..2a2adb07fe 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java @@ -15,11 +15,6 @@ */ package org.springframework.security.oauth2.client.registration; -import org.junit.Test; -import org.springframework.security.oauth2.core.AuthenticationMethod; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; - import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; @@ -27,8 +22,16 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.junit.Test; + +import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.security.oauth2.client.registration.ClientRegistration.withClientRegistration; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; /** * Tests for {@link ClientRegistration}. @@ -696,4 +699,72 @@ public class ClientRegistrationTests { assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI); assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME); } + + @Test + public void buildWhenClientRegistrationProvidedThenMakesACopy() { + ClientRegistration clientRegistration = clientRegistration().build(); + ClientRegistration updated = withClientRegistration(clientRegistration).build(); + assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes()); + assertThat(clientRegistration.getScopes()).isNotSameAs(updated.getScopes()); + assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) + .isEqualTo(updated.getProviderDetails().getConfigurationMetadata()); + assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) + .isNotSameAs(updated.getProviderDetails().getConfigurationMetadata()); + } + + @Test + public void buildWhenClientRegistrationProvidedThenEachPropertyMatches() { + ClientRegistration clientRegistration = clientRegistration().build(); + ClientRegistration updated = withClientRegistration(clientRegistration).build(); + assertThat(clientRegistration.getRegistrationId()).isEqualTo(updated.getRegistrationId()); + assertThat(clientRegistration.getClientId()).isEqualTo(updated.getClientId()); + assertThat(clientRegistration.getClientSecret()).isEqualTo(updated.getClientSecret()); + assertThat(clientRegistration.getClientAuthenticationMethod()) + .isEqualTo(updated.getClientAuthenticationMethod()); + assertThat(clientRegistration.getAuthorizationGrantType()) + .isEqualTo(updated.getAuthorizationGrantType()); + assertThat(clientRegistration.getRedirectUriTemplate()) + .isEqualTo(updated.getRedirectUriTemplate()); + assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes()); + + ClientRegistration.ProviderDetails providerDetails = clientRegistration.getProviderDetails(); + ClientRegistration.ProviderDetails updatedProviderDetails = updated.getProviderDetails(); + assertThat(providerDetails.getAuthorizationUri()) + .isEqualTo(updatedProviderDetails.getAuthorizationUri()); + assertThat(providerDetails.getTokenUri()) + .isEqualTo(updatedProviderDetails.getTokenUri()); + + ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint = providerDetails.getUserInfoEndpoint(); + ClientRegistration.ProviderDetails.UserInfoEndpoint updatedUserInfoEndpoint = updatedProviderDetails.getUserInfoEndpoint(); + assertThat(userInfoEndpoint.getUri()).isEqualTo(updatedUserInfoEndpoint.getUri()); + assertThat(userInfoEndpoint.getAuthenticationMethod()) + .isEqualTo(updatedUserInfoEndpoint.getAuthenticationMethod()); + assertThat(userInfoEndpoint.getUserNameAttributeName()) + .isEqualTo(updatedUserInfoEndpoint.getUserNameAttributeName()); + + assertThat(providerDetails.getJwkSetUri()).isEqualTo(updatedProviderDetails.getJwkSetUri()); + assertThat(providerDetails.getConfigurationMetadata()) + .isEqualTo(updatedProviderDetails.getConfigurationMetadata()); + + assertThat(clientRegistration.getClientName()).isEqualTo(updated.getClientName()); + } + + @Test + public void buildWhenClientRegistrationValuesOverriddenThenPropagated() { + ClientRegistration clientRegistration = clientRegistration().build(); + ClientRegistration updated = withClientRegistration(clientRegistration) + .clientSecret("a-new-secret") + .scope("a-new-scope") + .providerConfigurationMetadata(Collections.singletonMap("a-new-config", "a-new-value")) + .build(); + + assertThat(clientRegistration.getClientSecret()).isNotEqualTo(updated.getClientSecret()); + assertThat(updated.getClientSecret()).isEqualTo("a-new-secret"); + assertThat(clientRegistration.getScopes()).doesNotContain("a-new-scope"); + assertThat(updated.getScopes()).containsExactly("a-new-scope"); + assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) + .doesNotContainKey("a-new-config").doesNotContainValue("a-new-value"); + assertThat(updated.getProviderDetails().getConfigurationMetadata()) + .containsOnlyKeys("a-new-config").containsValue("a-new-value"); + } }