Customize when user info is called

Closes gh-13259
This commit is contained in:
Steve Riesenberg 2024-01-29 14:54:44 -06:00
parent 27b370b534
commit 96e3e4f8b1
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
4 changed files with 143 additions and 4 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import reactor.core.publisher.Mono;
@ -33,6 +34,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -71,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
/**
* Returns the default {@link Converter}'s used for type conversion of claim values
* for an {@link OidcUserInfo}.
@ -123,7 +127,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
}
private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) {
if (!this.retrieveUserInfo.test(userRequest)) {
return Mono.empty();
}
// @formatter:off
@ -169,4 +173,24 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
this.claimTypeConverterFactory = claimTypeConverterFactory;
}
/**
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
* called to retrieve information about the End-User (Resource Owner).
* <p>
* By default, the UserInfo Endpoint is called if all of the following are true:
* <ul>
* <li>The user info endpoint is defined on the ClientRegistration</li>
* <li>The Client Registration uses the
* {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token
* are defined in the {@link ClientRegistration}</li>
* </ul>
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
* should be called
* @since 6.3
*/
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
this.retrieveUserInfo = retrieveUserInfo;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.Converter;
@ -78,6 +79,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
/**
* Returns the default {@link Converter}'s used for type conversion of claim values
* for an {@link OidcUserInfo}.
@ -105,7 +108,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
OidcUserInfo userInfo = null;
if (this.shouldRetrieveUserInfo(userRequest)) {
if (this.retrieveUserInfo.test(userRequest)) {
OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
Map<String, Object> claims = getClaims(userRequest, oauth2User);
userInfo = new OidcUserInfo(claims);
@ -221,10 +224,35 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
* resource will be requested, otherwise it will not.
* @param accessibleScopes the scope(s) that allow access to the user info resource
* @since 5.2
* @deprecated Use {@link #setRetrieveUserInfo(Predicate)} instead
*/
@Deprecated(since = "6.3", forRemoval = true)
public final void setAccessibleScopes(Set<String> accessibleScopes) {
Assert.notNull(accessibleScopes, "accessibleScopes cannot be null");
this.accessibleScopes = accessibleScopes;
}
/**
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
* called to retrieve information about the End-User (Resource Owner).
* <p>
* By default, the UserInfo Endpoint is called if all of the following are true:
* <ul>
* <li>The user info endpoint is defined on the ClientRegistration</li>
* <li>The Client Registration uses the
* {@link AuthorizationGrantType#AUTHORIZATION_CODE}</li>
* <li>The access token contains one or more scopes allowed to access the UserInfo
* Endpoint ({@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email},
* {@link OidcScopes#ADDRESS address} or {@link OidcScopes#PHONE phone}) or the access
* token scopes are empty</li>
* </ul>
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
* should be called
* @since 6.3
*/
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
this.retrieveUserInfo = retrieveUserInfo;
}
}

View File

@ -24,6 +24,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -107,6 +108,15 @@ public class OidcReactiveOAuth2UserServiceTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null));
}
@Test
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
.withMessage("retrieveUserInfo cannot be null");
// @formatter:on
}
@Test
public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() {
this.registration.userInfoUri(null);
@ -183,6 +193,48 @@ public class OidcReactiveOAuth2UserServiceTests {
verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
}
@Test
public void loadUserWhenTokenScopesIsEmptyThenUserInfoNotRetrieved() {
// @formatter:off
OAuth2AccessToken accessToken = new OAuth2AccessToken(
this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
this.accessToken.getIssuedAt(),
this.accessToken.getExpiresAt(),
Collections.emptySet());
// @formatter:on
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
assertThat(oidcUser).isNotNull();
assertThat(oidcUser.getUserInfo()).isNull();
}
@Test
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
Map<String, Object> attributes = new HashMap<>();
attributes.put(StandardClaimNames.SUB, "subject");
attributes.put("user", "steve");
OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes,
"user");
given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User));
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
// @formatter:off
OAuth2AccessToken accessToken = new OAuth2AccessToken(
this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
this.accessToken.getIssuedAt(),
this.accessToken.getExpiresAt(),
Collections.emptySet());
// @formatter:on
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
assertThat(oidcUser).isNotNull();
assertThat(oidcUser.getUserInfo()).isNotNull();
verify(customRetrieveUserInfo).test(userRequest);
}
@Test
public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();

View File

@ -23,6 +23,7 @@ import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Predicate;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -58,6 +59,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -129,6 +131,15 @@ public class OidcUserServiceTests {
this.userService.setAccessibleScopes(Collections.emptySet());
}
@Test
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
.withMessage("retrieveUserInfo cannot be null");
// @formatter:on
}
@Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
@ -218,6 +229,30 @@ public class OidcUserServiceTests {
assertThat(user.getUserInfo()).isNotNull();
}
@Test
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
// @formatter:off
String userInfoResponse = "{\n"
+ " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
this.accessToken = TestOAuth2AccessTokens.noScopes();
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
OidcUser user = this.userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getUserInfo()).isNotNull();
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
// @formatter:off