Grant Individual Authorities From Claims

Fixes gh-7339
This commit is contained in:
Josh Cummings 2019-09-04 04:34:44 -06:00
parent 409285fb3d
commit aa1c80c801
5 changed files with 268 additions and 24 deletions

View File

@ -15,6 +15,17 @@
*/
package org.springframework.security.oauth2.client.oidc.userinfo;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.GrantedAuthority;
@ -38,15 +49,6 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
/**
* An implementation of an {@link OAuth2UserService} that supports OpenID Connect 1.0 Provider's.
*
@ -94,6 +96,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
OidcUserInfo userInfo = null;
Collection<? extends GrantedAuthority> oauth2UserAuthorities = Collections.emptyList();
if (this.shouldRetrieveUserInfo(userRequest)) {
OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
@ -106,6 +109,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes());
}
userInfo = new OidcUserInfo(claims);
oauth2UserAuthorities = oauth2User.getAuthorities();
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
@ -127,8 +131,9 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
}
}
Set<GrantedAuthority> authorities = Collections.singleton(
new OidcUserAuthority(userRequest.getIdToken(), userInfo));
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
authorities.addAll(oauth2UserAuthorities);
OidcUser user;

View File

@ -15,13 +15,22 @@
*/
package org.springframework.security.oauth2.client.userinfo;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -35,10 +44,6 @@ import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
/**
* An implementation of an {@link OAuth2UserService} that supports standard OAuth 2.0 Provider's.
* <p>
@ -66,6 +71,9 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
new ParameterizedTypeReference<Map<String, Object>>() {};
private static final Collection<String> WELL_KNOWN_AUTHORITIES_CLAIM_NAMES =
Arrays.asList("scope", "scp");
private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
private RestOperations restOperations;
@ -127,7 +135,11 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
}
Map<String, Object> userAttributes = response.getBody();
Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(userAttributes));
for (String authority : getAuthorities(() -> userAttributes)) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
}
@ -160,4 +172,34 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}
private String getAuthoritiesClaimName(ClaimAccessor claims) {
for (String claimName : WELL_KNOWN_AUTHORITIES_CLAIM_NAMES) {
if (claims.containsClaim(claimName)) {
return claimName;
}
}
return null;
}
private Collection<String> getAuthorities(ClaimAccessor claims) {
String claimName = getAuthoritiesClaimName(claims);
if (claimName == null) {
return Collections.emptyList();
}
Object authorities = claims.getClaim(claimName);
if (authorities instanceof String) {
if (StringUtils.hasText((String) authorities)) {
return Arrays.asList(((String) authorities).split(" "));
} else {
return Collections.emptyList();
}
} else if (authorities instanceof Collection) {
return (Collection<String>) authorities;
}
return Collections.emptyList();
}
}

View File

@ -15,6 +15,15 @@
*/
package org.springframework.security.oauth2.client.oidc.userinfo;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -23,12 +32,20 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@ -39,20 +56,20 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.springframework.web.client.RestOperations;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.nullable;
import static org.mockito.Mockito.same;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
/**
* Tests for {@link OidcUserService}.
@ -481,6 +498,73 @@ public class OidcUserServiceTests {
verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
}
@Test
public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("sub", "test-subject");
body.put("scope", "message:read message:write");
OidcUserService userService = new OidcUserService();
userService.setOauth2UserService(withMockResponse(body));
OidcUserRequest request = new OidcUserRequest(clientRegistration().
userInfoUri("uri").build(), scopes("profile"), idToken(body));
OidcUser user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(3);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
}
@Test
public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("sub", "test-subject");
body.put("scp", Arrays.asList("message:read", "message:write"));
OidcUserService userService = new OidcUserService();
userService.setOauth2UserService(withMockResponse(body));
OidcUserRequest request = new OidcUserRequest(clientRegistration().
userInfoUri("uri").build(), scopes("profile"), idToken(body));
OidcUser user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(3);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
}
@Test
public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("sub", "test-subject");
body.put("authorities", Arrays.asList("message:read", "message:write"));
OidcUserService userService = new OidcUserService();
userService.setOauth2UserService(withMockResponse(body));
OidcUserRequest request = new OidcUserRequest(clientRegistration().
userInfoUri("uri").build(), scopes("profile"), idToken(body));
OidcUser user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(1);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
}
private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) {
ResponseEntity<Map<String, Object>> responseEntity = new ResponseEntity<>(response, HttpStatus.OK);
Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = mock(Converter.class);
RestOperations rest = mock(RestOperations.class);
when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class)))
.thenReturn(responseEntity);
DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
userService.setRequestEntityConverter(requestEntityConverter);
userService.setRestOperations(rest);
return userService;
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)

View File

@ -15,6 +15,10 @@
*/
package org.springframework.security.oauth2.client.userinfo;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import okhttp3.mockwebserver.MockResponse;
@ -26,18 +30,30 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.web.client.RestOperations;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
@ -325,6 +341,64 @@ public class DefaultOAuth2UserServiceTests {
assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
}
@Test
public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("scope", "message:read message:write");
DefaultOAuth2UserService userService = withMockResponse(body);
OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
OAuth2User user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(3);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
}
@Test
public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("scp", Arrays.asList("message:read", "message:write"));
DefaultOAuth2UserService userService = withMockResponse(body);
OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
OAuth2User user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(3);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
}
@Test
public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() {
Map<String, Object> body = new HashMap<>();
body.put("id", "id");
body.put("authorities", Arrays.asList("message:read", "message:write"));
DefaultOAuth2UserService userService = withMockResponse(body);
OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
OAuth2User user = userService.loadUser(request);
assertThat(user.getAuthorities()).hasSize(1);
Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
}
private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) {
ResponseEntity<Map<String, Object>> responseEntity = new ResponseEntity<>(response, HttpStatus.OK);
Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = mock(Converter.class);
RestOperations rest = mock(RestOperations.class);
when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class)))
.thenReturn(responseEntity);
DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
userService.setRequestEntityConverter(requestEntityConverter);
userService.setRestOperations(rest);
return userService;
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.core.oidc;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
/**
* Test {@link OidcIdToken}s
*
* @author Josh Cummings
*/
public class TestOidcIdTokens {
public static OidcIdToken idToken() {
return idToken(Collections.singletonMap("id", "id"));
}
public static OidcIdToken idToken(Map<String, Object> claims) {
return new OidcIdToken("token",
Instant.now(),
Instant.now().plusSeconds(86400),
claims);
}
}