diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index d99a6a595e..4313aa7497 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -486,7 +486,7 @@ public final class ClientRegistration implements Serializable { this.validateClientCredentialsGrantType(); } else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { this.validateImplicitGrantType(); - } else { + } else if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)) { this.validateAuthorizationCodeGrantType(); } this.validateScopes(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java index 09fc9fd960..0b10d0946e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java @@ -589,4 +589,27 @@ public class ClientRegistrationTests { .build() ).isInstanceOf(IllegalArgumentException.class); } + + @Test + public void buildWhenCustomGrantAllAttributesProvidedThenAllAttributesAreSet() { + AuthorizationGrantType customGrantType = new AuthorizationGrantType("CUSTOM"); + ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(customGrantType) + .scope(SCOPES.toArray(new String[0])) + .tokenUri(TOKEN_URI) + .clientName(CLIENT_NAME) + .build(); + + assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); + assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(customGrantType); + assertThat(registration.getScopes()).isEqualTo(SCOPES); + assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI); + assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME); + } }