diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 61b8254d81..da3aac3343 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -167,6 +167,7 @@ public final class OAuth2LoginConfigurer> exten public class UserInfoEndpointConfig { private OAuth2UserService userService; + private OAuth2UserService oidcUserService; private Map> customUserTypes = new HashMap<>(); private GrantedAuthoritiesMapper userAuthoritiesMapper; @@ -179,6 +180,12 @@ public final class OAuth2LoginConfigurer> exten return this; } + public UserInfoEndpointConfig oidcUserService(OAuth2UserService oidcUserService) { + Assert.notNull(oidcUserService, "oidcUserService cannot be null"); + this.oidcUserService = oidcUserService; + return this; + } + public UserInfoEndpointConfig customUserType(Class customUserType, String clientRegistrationId) { Assert.notNull(customUserType, "customUserType cannot be null"); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); @@ -227,7 +234,6 @@ public final class OAuth2LoginConfigurer> exten } } - OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = new OAuth2LoginAuthenticationProvider(accessTokenResponseClient, oauth2UserService); if (this.userInfoEndpointConfig.userAuthoritiesMapper != null) { @@ -236,8 +242,12 @@ public final class OAuth2LoginConfigurer> exten } http.authenticationProvider(this.postProcess(oauth2LoginAuthenticationProvider)); - OAuth2UserService oidcUserService = new OidcUserService(); + OAuth2UserService oidcUserService = this.userInfoEndpointConfig.oidcUserService; + if (oidcUserService == null) { + oidcUserService = new OidcUserService(); + } JwtDecoderRegistry jwtDecoderRegistry = new NimbusJwtDecoderRegistry(); + OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider( accessTokenResponseClient, oidcUserService, jwtDecoderRegistry);