From e52dd81d03df39405672264d36bb8023bf53c7d2 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Fri, 1 Mar 2024 11:45:58 -0600 Subject: [PATCH] Customize mapping the OidcUser Closes gh-14672 --- .../OidcReactiveOAuth2UserService.java | 90 ++++++++++++++----- .../oidc/userinfo/OidcUserRequestUtils.java | 25 ++++++ .../client/oidc/userinfo/OidcUserService.java | 80 +++++++++++++---- .../OidcReactiveOAuth2UserServiceTests.java | 53 +++++++++++ .../oidc/userinfo/OidcUserServiceTests.java | 46 ++++++++++ 5 files changed, 252 insertions(+), 42 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index 566b248d6a..21d2c7d012 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -18,9 +18,8 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; -import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; @@ -28,7 +27,6 @@ import reactor.core.publisher.Mono; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; @@ -40,6 +38,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; @@ -47,7 +46,6 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; /** * An implementation of an {@link ReactiveOAuth2UserService} that supports OpenID Connect @@ -75,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< private Predicate retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo; + private BiFunction> oidcUserMapper = this::getUser; + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcUserInfo}. @@ -103,29 +103,15 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< Assert.notNull(userRequest, "userRequest cannot be null"); // @formatter:off return getUserInfo(userRequest) - .map((userInfo) -> - new OidcUserAuthority(userRequest.getIdToken(), userInfo) - ) - .defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null)) - .map((authority) -> { - OidcUserInfo userInfo = authority.getUserInfo(); - Set authorities = new HashSet<>(); - authorities.add(authority); - OAuth2AccessToken token = userRequest.getAccessToken(); - for (String scope : token.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); - } - String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails() - .getUserInfoEndpoint().getUserNameAttributeName(); - if (StringUtils.hasText(userNameAttributeName)) { - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, - userNameAttributeName); - } - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); - }); + .flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo)) + .switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null))); // @formatter:on } + private Mono getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { + return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo)); + } + private Mono getUserInfo(OidcUserRequest userRequest) { if (!this.retrieveUserInfo.test(userRequest)) { return Mono.empty(); @@ -193,4 +179,60 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< this.retrieveUserInfo = retrieveUserInfo; } + /** + * Sets the {@code BiFunction} used to map the {@link OidcUser user} from the + * {@link OidcUserRequest user request} and {@link OidcUserInfo user info}. + *

+ * This is useful when you need to map the user or authorities from the access token + * itself. For example, when the authorization server provides authorization + * information in the access token payload you can do the following:

+	 * 	@Bean
+	 * 	public OidcReactiveOAuth2UserService oidcUserService() {
+	 * 		var userService = new OidcReactiveOAuth2UserService();
+	 * 		userService.setOidcUserMapper(oidcUserMapper());
+	 * 		return userService;
+	 * 	}
+	 *
+	 * 	private static BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper() {
+	 * 		return (userRequest, userInfo) -> {
+	 * 			var accessToken = userRequest.getAccessToken();
+	 * 			var grantedAuthorities = new HashSet<GrantedAuthority>();
+	 * 			// TODO: Map authorities from the access token
+	 * 			var userNameAttributeName = "preferred_username";
+	 * 			return Mono.just(new DefaultOidcUser(
+	 * 				grantedAuthorities,
+	 * 				userRequest.getIdToken(),
+	 * 				userInfo,
+	 * 				userNameAttributeName
+	 * 			));
+	 * 		};
+	 * 	}
+	 * 
+ *

+ * Note that you can access the {@code userNameAttributeName} via the + * {@link ClientRegistration} as follows:

+	 * 	var userNameAttributeName = userRequest.getClientRegistration()
+	 * 		.getProviderDetails()
+	 * 		.getUserInfoEndpoint()
+	 * 		.getUserNameAttributeName();
+	 * 
+ *

+ * By default, a {@link DefaultOidcUser} is created with authorities mapped as + * follows: + *

    + *
  • An {@link OidcUserAuthority} is created from the {@link OidcIdToken} and + * {@link OidcUserInfo} with an authority of {@code OIDC_USER}
  • + *
  • Additional {@link SimpleGrantedAuthority authorities} are mapped from the + * {@link OAuth2AccessToken#getScopes() access token scopes} with a prefix of + * {@code SCOPE_}
  • + *
+ * @param oidcUserMapper the function used to map the {@link OidcUser} from the + * {@link OidcUserRequest} and {@link OidcUserInfo} + * @since 6.3 + */ + public final void setOidcUserMapper(BiFunction> oidcUserMapper) { + Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); + this.oidcUserMapper = oidcUserMapper; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java index 1cd71aa072..5f39b8d961 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java @@ -16,8 +16,18 @@ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -66,6 +76,21 @@ final class OidcUserRequestUtils { return false; } + static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { + Set authorities = new LinkedHashSet<>(); + authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } + ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); + String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); + if (StringUtils.hasText(userNameAttributeName)) { + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); + } + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); + } + private OidcUserRequestUtils() { } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index a0d1cd26aa..a7b0151ae0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -20,15 +20,14 @@ import java.time.Instant; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails; @@ -41,6 +40,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; @@ -57,6 +57,7 @@ import org.springframework.util.StringUtils; * Provider's. * * @author Joe Grandja + * @author Steve Riesenberg * @since 5.0 * @see OAuth2UserService * @see OidcUserRequest @@ -81,6 +82,8 @@ public class OidcUserService implements OAuth2UserService retrieveUserInfo = this::shouldRetrieveUserInfo; + private BiFunction oidcUserMapper = OidcUserRequestUtils::getUser; + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcUserInfo}. @@ -130,13 +133,7 @@ public class OidcUserService implements OAuth2UserService authorities = new LinkedHashSet<>(); - authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); - OAuth2AccessToken token = userRequest.getAccessToken(); - for (String authority : token.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); - } - return getUser(userRequest, userInfo, authorities); + return this.oidcUserMapper.apply(userRequest, userInfo); } private Map getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) { @@ -148,15 +145,6 @@ public class OidcUserService implements OAuth2UserService authorities) { - ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); - String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); - if (StringUtils.hasText(userNameAttributeName)) { - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); - } - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); - } - private boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) { // Auto-disabled if UserInfo Endpoint URI is not provided ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); @@ -255,4 +243,60 @@ public class OidcUserService implements OAuth2UserService + * This is useful when you need to map the user or authorities from the access token + * itself. For example, when the authorization server provides authorization + * information in the access token payload you can do the following:
+	 * 	@Bean
+	 * 	public OidcUserService oidcUserService() {
+	 * 		var userService = new OidcUserService();
+	 * 		userService.setOidcUserMapper(oidcUserMapper());
+	 * 		return userService;
+	 * 	}
+	 *
+	 * 	private static BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> oidcUserMapper() {
+	 * 		return (userRequest, userInfo) -> {
+	 * 			var accessToken = userRequest.getAccessToken();
+	 * 			var grantedAuthorities = new HashSet<GrantedAuthority>();
+	 * 			// TODO: Map authorities from the access token
+	 * 			var userNameAttributeName = "preferred_username";
+	 * 			return new DefaultOidcUser(
+	 * 				grantedAuthorities,
+	 * 				userRequest.getIdToken(),
+	 * 				userInfo,
+	 * 				userNameAttributeName
+	 * 			);
+	 * 		};
+	 * 	}
+	 * 
+ *

+ * Note that you can access the {@code userNameAttributeName} via the + * {@link ClientRegistration} as follows:

+	 * 	var userNameAttributeName = userRequest.getClientRegistration()
+	 * 		.getProviderDetails()
+	 * 		.getUserInfoEndpoint()
+	 * 		.getUserNameAttributeName();
+	 * 
+ *

+ * By default, a {@link DefaultOidcUser} is created with authorities mapped as + * follows: + *

    + *
  • An {@link OidcUserAuthority} is created from the {@link OidcIdToken} and + * {@link OidcUserInfo} with an authority of {@code OIDC_USER}
  • + *
  • Additional {@link SimpleGrantedAuthority authorities} are mapped from the + * {@link OAuth2AccessToken#getScopes() access token scopes} with a prefix of + * {@code SCOPE_}
  • + *
+ * @param oidcUserMapper the function used to map the {@link OidcUser} from the + * {@link OidcUserRequest} and {@link OidcUserInfo} + * @since 6.3 + */ + public final void setOidcUserMapper(BiFunction oidcUserMapper) { + Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); + this.oidcUserMapper = oidcUserMapper; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index bf20d712f6..ade191c6da 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; @@ -31,6 +32,7 @@ import okhttp3.mockwebserver.MockWebServer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Mono; @@ -53,8 +55,10 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -64,6 +68,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.same; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -235,6 +241,53 @@ public class OidcReactiveOAuth2UserServiceTests { verify(customRetrieveUserInfo).test(userRequest); } + @Test + public void loadUserWhenCustomOidcUserMapperSetThenUsed() { + Map attributes = new HashMap<>(); + attributes.put(StandardClaimNames.SUB, "subject"); + attributes.put("user", "steve"); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any(OidcUserRequest.class))).willReturn(Mono.just(oauth2User)); + BiFunction> customOidcUserMapper = mock(BiFunction.class); + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + IdTokenClaimNames.SUB); + given(customOidcUserMapper.apply(any(OidcUserRequest.class), any(OidcUserInfo.class))) + .willReturn(Mono.just(actualUser)); + this.userService.setOidcUserMapper(customOidcUserMapper); + OidcUserRequest userRequest = userRequest(); + OidcUser oidcUser = this.userService.loadUser(userRequest).block(); + assertThat(oidcUser).isNotNull(); + assertThat(oidcUser).isEqualTo(actualUser); + ArgumentCaptor userInfoCaptor = ArgumentCaptor.forClass(OidcUserInfo.class); + verify(customOidcUserMapper).apply(eq(userRequest), userInfoCaptor.capture()); + OidcUserInfo userInfo = userInfoCaptor.getValue(); + assertThat(userInfo.getSubject()).isEqualTo("subject"); + assertThat(userInfo.getClaimAsString("user")).isEqualTo("steve"); + } + + @Test + public void loadUserWhenCustomOidcUserMapperSetAndUserInfoNotRetrievedThenUsed() { + // @formatter:off + this.accessToken = new OAuth2AccessToken( + this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + this.accessToken.getIssuedAt(), + this.accessToken.getExpiresAt(), + Collections.emptySet()); + // @formatter:on + BiFunction> customOidcUserMapper = mock(BiFunction.class); + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + IdTokenClaimNames.SUB); + given(customOidcUserMapper.apply(any(OidcUserRequest.class), isNull())).willReturn(Mono.just(actualUser)); + this.userService.setOidcUserMapper(customOidcUserMapper); + OidcUserRequest userRequest = userRequest(); + OidcUser oidcUser = this.userService.loadUser(userRequest).block(); + assertThat(oidcUser).isNotNull(); + assertThat(oidcUser).isEqualTo(actualUser); + verify(customOidcUserMapper).apply(eq(userRequest), isNull(OidcUserInfo.class)); + } + @Test public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 4b08664a80..f9982c8226 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; @@ -31,12 +32,14 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -49,8 +52,10 @@ import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -60,6 +65,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -140,6 +146,15 @@ public class OidcUserServiceTests { // @formatter:on } + @Test + public void setOidcUserMapperWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.userService.setOidcUserMapper(null)) + .withMessage("oidcUserMapper cannot be null"); + // @formatter:on + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); @@ -253,6 +268,37 @@ public class OidcUserServiceTests { assertThat(user.getUserInfo()).isNotNull(); } + @Test + public void loadUserWhenCustomOidcUserMapperSetThenUsed() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(userInfoResponse)); + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + this.accessToken = TestOAuth2AccessTokens.noScopes(); + BiFunction customOidcUserMapper = mock(BiFunction.class); + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + IdTokenClaimNames.SUB); + given(customOidcUserMapper.apply(any(OidcUserRequest.class), any(OidcUserInfo.class))).willReturn(actualUser); + this.userService.setOidcUserMapper(customOidcUserMapper); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken); + OidcUser user = this.userService.loadUser(userRequest); + assertThat(user).isEqualTo(actualUser); + ArgumentCaptor userInfoCaptor = ArgumentCaptor.forClass(OidcUserInfo.class); + verify(customOidcUserMapper).apply(eq(userRequest), userInfoCaptor.capture()); + OidcUserInfo userInfo = userInfoCaptor.getValue(); + assertThat(userInfo.getSubject()).isEqualTo("subject1"); + assertThat(userInfo.getClaimAsString("preferred_username")).isEqualTo("user1"); + } + @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { // @formatter:off