diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index fadc5f8b92..5893d799d8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -18,10 +18,12 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -99,6 +101,10 @@ public class OidcReactiveOAuth2UserService implements OidcUserInfo userInfo = authority.getUserInfo(); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } String userNameAttributeName = userRequest.getClientRegistration() .getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); if (StringUtils.hasText(userNameAttributeName)) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index ec374fbffd..3fb35f5ada 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -17,8 +17,6 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; @@ -29,11 +27,13 @@ import java.util.function.Function; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -96,7 +96,6 @@ public class OidcUserService implements OAuth2UserService oauth2UserAuthorities = Collections.emptyList(); if (this.shouldRetrieveUserInfo(userRequest)) { OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest); @@ -109,7 +108,6 @@ public class OidcUserService implements OAuth2UserService authorities = new LinkedHashSet<>(); authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); - authorities.addAll(oauth2UserAuthorities); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } OidcUser user; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index c274ebd816..0ca84ca6d4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -15,9 +15,6 @@ */ package org.springframework.security.oauth2.client.userinfo; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -30,7 +27,7 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.ClaimAccessor; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -71,9 +68,6 @@ public class DefaultOAuth2UserService implements OAuth2UserService> PARAMETERIZED_RESPONSE_TYPE = new ParameterizedTypeReference>() {}; - private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = - Arrays.asList("scope", "scp"); - private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); private RestOperations restOperations; @@ -137,7 +131,8 @@ public class DefaultOAuth2UserService implements OAuth2UserService userAttributes = response.getBody(); Set authorities = new LinkedHashSet<>(); authorities.add(new OAuth2UserAuthority(userAttributes)); - for (String authority : getAuthorities(() -> userAttributes)) { + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); } @@ -172,34 +167,4 @@ public class DefaultOAuth2UserService implements OAuth2UserService getAuthorities(ClaimAccessor claims) { - String claimName = getAuthoritiesClaimName(claims); - - if (claimName == null) { - return Collections.emptyList(); - } - - Object authorities = claims.getClaim(claimName); - if (authorities instanceof String) { - if (StringUtils.hasText((String) authorities)) { - return Arrays.asList(((String) authorities).split(" ")); - } else { - return Collections.emptyList(); - } - } else if (authorities instanceof Collection) { - return (Collection) authorities; - } - - return Collections.emptyList(); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index b808cb8068..108d837240 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -28,7 +28,9 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; @@ -131,6 +133,10 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); }) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index a713515b69..0876c48247 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -16,13 +16,25 @@ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; @@ -36,17 +48,20 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; -import reactor.core.publisher.Mono; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; -import java.time.Duration; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; +import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * @author Rob Winch @@ -178,6 +193,38 @@ public class OidcReactiveOAuth2UserServiceTests { verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration())); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), scopes("message:read", "message:write"), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), noScopes(), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 473666e70f..6b414375ec 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -16,7 +16,6 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -33,19 +32,14 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; -import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -56,18 +50,16 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; -import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.nullable; import static org.mockito.Mockito.same; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; @@ -272,7 +264,7 @@ public class OidcUserServiceTests { assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1"); assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); + assertThat(user.getAuthorities().size()).isEqualTo(3); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class); OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next(); assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER"); @@ -499,15 +491,13 @@ public class OidcUserServiceTests { } @Test - public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); body.put("sub", "test-subject"); - body.put("scope", "message:read message:write"); OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + scopes("message:read", "message:write"), idToken(body)); OidcUser user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(3); @@ -518,34 +508,13 @@ public class OidcUserServiceTests { } @Test - public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); body.put("sub", "test-subject"); - body.put("scp", Arrays.asList("message:read", "message:write")); OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); - OidcUser user = userService.loadUser(request); - - assertThat(user.getAuthorities()).hasSize(3); - Iterator authorities = user.getAuthorities().iterator(); - assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); - } - - @Test - public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { - Map body = new HashMap<>(); - body.put("id", "id"); - body.put("sub", "test-subject"); - body.put("authorities", Arrays.asList("message:read", "message:write")); - OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + noScopes(), idToken(body)); OidcUser user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(1); @@ -553,18 +522,6 @@ public class OidcUserServiceTests { assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); } - private DefaultOAuth2UserService withMockResponse(Map response) { - ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); - Converter> requestEntityConverter = mock(Converter.class); - RestOperations rest = mock(RestOperations.class); - when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) - .thenReturn(responseEntity); - DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); - userService.setRequestEntityConverter(requestEntityConverter); - userService.setRestOperations(rest); - return userService; - } - private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 346ac21dfc..f42e167b83 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -15,7 +15,6 @@ */ package org.springframework.security.oauth2.client.userinfo; -import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.Map; @@ -56,6 +55,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * Tests for {@link DefaultOAuth2UserService}. @@ -342,12 +342,12 @@ public class DefaultOAuth2UserServiceTests { } @Test - public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); - body.put("scope", "message:read message:write"); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); OAuth2User user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(3); @@ -358,28 +358,12 @@ public class DefaultOAuth2UserServiceTests { } @Test - public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); - body.put("scp", Arrays.asList("message:read", "message:write")); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); - OAuth2User user = userService.loadUser(request); - - assertThat(user.getAuthorities()).hasSize(3); - Iterator authorities = user.getAuthorities().iterator(); - assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); - } - - @Test - public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { - Map body = new HashMap<>(); - body.put("id", "id"); - body.put("authorities", Arrays.asList("message:read", "message:write")); - DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); OAuth2User user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(1); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java index b856c80252..8d1725b37c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java @@ -16,15 +16,30 @@ package org.springframework.security.oauth2.client.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthenticationMethod; @@ -32,14 +47,17 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; +import org.springframework.web.reactive.function.client.WebClient; -import okhttp3.mockwebserver.RecordedRequest; -import reactor.test.StepVerifier; - -import java.time.Duration; -import java.time.Instant; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * @author Rob Winch @@ -211,6 +229,53 @@ public class DefaultReactiveOAuth2UserServiceTests { .isInstanceOf(AuthenticationServiceException.class); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + + private DefaultReactiveOAuth2UserService withMockResponse(Map body) { + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.post()); + WebClient rest = spy(WebClient.class); + WebClient.ResponseSpec clientResponse = mock(WebClient.ResponseSpec.class); + when(rest.get()).thenReturn(spec); + when(spec.retrieve()).thenReturn(clientResponse); + when(clientResponse.onStatus(any(Predicate.class), any(Function.class))) + .thenReturn(clientResponse); + when(clientResponse.bodyToMono(any(ParameterizedTypeReference.class))) + .thenReturn(Mono.just(body)); + + DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); + userService.setWebClient(rest); + return userService; + } + private OAuth2UserRequest oauth2UserRequest() { return new OAuth2UserRequest(this.clientRegistration.build(), this.accessToken); }