From 11984039c24b67e7b14d37c932397f58db68621c Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 8 Aug 2018 09:32:09 -0400 Subject: [PATCH] Add OidcUserService.setOauth2UserService() Fixes gh-5604 --- .../client/oidc/userinfo/OidcUserService.java | 21 ++++++++++++++----- .../oidc/userinfo/OidcUserServiceTests.java | 11 +++++++++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index ae3c6505a0..f08c5ac200 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -33,6 +33,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.Set; @@ -51,14 +52,14 @@ public class OidcUserService implements OAuth2UserService userInfoScopes = new HashSet<>( Arrays.asList(OidcScopes.PROFILE, OidcScopes.EMAIL, OidcScopes.ADDRESS, OidcScopes.PHONE)); - private final OAuth2UserService defaultUserService = new DefaultOAuth2UserService(); + private OAuth2UserService oauth2UserService = new DefaultOAuth2UserService(); @Override public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { Assert.notNull(userRequest, "userRequest cannot be null"); OidcUserInfo userInfo = null; if (this.shouldRetrieveUserInfo(userRequest)) { - OAuth2User oauth2User = this.defaultUserService.loadUser(userRequest); + OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest); userInfo = new OidcUserInfo(oauth2User.getAttributes()); // http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse @@ -81,9 +82,8 @@ public class OidcUserService implements OAuth2UserService authorities = new HashSet<>(); - authorities.add(authority); + Set authorities = Collections.singleton( + new OidcUserAuthority(userRequest.getIdToken(), userInfo)); OidcUser user; @@ -121,4 +121,15 @@ public class OidcUserService implements OAuth2UserService oauth2UserService) { + Assert.notNull(oauth2UserService, "oauth2UserService cannot be null"); + this.oauth2UserService = oauth2UserService; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 1586292af6..e388a800fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; - import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -32,6 +31,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -51,6 +51,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -96,6 +97,14 @@ public class OidcUserServiceTests { idTokenClaims.put(IdTokenClaimNames.SUB, "subject1"); when(this.idToken.getClaims()).thenReturn(idTokenClaims); when(this.idToken.getSubject()).thenReturn("subject1"); + + this.userService.setOauth2UserService(new DefaultOAuth2UserService()); + } + + @Test + public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.userService.setOauth2UserService(null)) + .isInstanceOf(IllegalArgumentException.class); } @Test