mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-06-29 07:12:32 +00:00
Customize when user info is called
Closes gh-13259
This commit is contained in:
parent
27b370b534
commit
96e3e4f8b1
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2019 the original author or authors.
|
* Copyright 2002-2024 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -22,6 +22,7 @@ import java.util.HashSet;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
|
|||||||
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
|
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
|
||||||
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
|
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
|
||||||
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
|
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
|
||||||
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||||
@ -71,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|||||||
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
|
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
|
||||||
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
|
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
|
||||||
|
|
||||||
|
private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the default {@link Converter}'s used for type conversion of claim values
|
* Returns the default {@link Converter}'s used for type conversion of claim values
|
||||||
* for an {@link OidcUserInfo}.
|
* for an {@link OidcUserInfo}.
|
||||||
@ -123,7 +127,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
|
private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
|
||||||
if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) {
|
if (!this.retrieveUserInfo.test(userRequest)) {
|
||||||
return Mono.empty();
|
return Mono.empty();
|
||||||
}
|
}
|
||||||
// @formatter:off
|
// @formatter:off
|
||||||
@ -169,4 +173,24 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|||||||
this.claimTypeConverterFactory = claimTypeConverterFactory;
|
this.claimTypeConverterFactory = claimTypeConverterFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
|
||||||
|
* called to retrieve information about the End-User (Resource Owner).
|
||||||
|
* <p>
|
||||||
|
* By default, the UserInfo Endpoint is called if all of the following are true:
|
||||||
|
* <ul>
|
||||||
|
* <li>The user info endpoint is defined on the ClientRegistration</li>
|
||||||
|
* <li>The Client Registration uses the
|
||||||
|
* {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token
|
||||||
|
* are defined in the {@link ClientRegistration}</li>
|
||||||
|
* </ul>
|
||||||
|
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
|
||||||
|
* should be called
|
||||||
|
* @since 6.3
|
||||||
|
*/
|
||||||
|
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
|
||||||
|
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
|
||||||
|
this.retrieveUserInfo = retrieveUserInfo;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2019 the original author or authors.
|
* Copyright 2002-2024 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -24,6 +24,7 @@ import java.util.LinkedHashSet;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import org.springframework.core.convert.TypeDescriptor;
|
import org.springframework.core.convert.TypeDescriptor;
|
||||||
import org.springframework.core.convert.converter.Converter;
|
import org.springframework.core.convert.converter.Converter;
|
||||||
@ -78,6 +79,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
|
|||||||
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
|
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
|
||||||
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
|
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
|
||||||
|
|
||||||
|
private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the default {@link Converter}'s used for type conversion of claim values
|
* Returns the default {@link Converter}'s used for type conversion of claim values
|
||||||
* for an {@link OidcUserInfo}.
|
* for an {@link OidcUserInfo}.
|
||||||
@ -105,7 +108,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
|
|||||||
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
|
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
|
||||||
Assert.notNull(userRequest, "userRequest cannot be null");
|
Assert.notNull(userRequest, "userRequest cannot be null");
|
||||||
OidcUserInfo userInfo = null;
|
OidcUserInfo userInfo = null;
|
||||||
if (this.shouldRetrieveUserInfo(userRequest)) {
|
if (this.retrieveUserInfo.test(userRequest)) {
|
||||||
OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
|
OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
|
||||||
Map<String, Object> claims = getClaims(userRequest, oauth2User);
|
Map<String, Object> claims = getClaims(userRequest, oauth2User);
|
||||||
userInfo = new OidcUserInfo(claims);
|
userInfo = new OidcUserInfo(claims);
|
||||||
@ -221,10 +224,35 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
|
|||||||
* resource will be requested, otherwise it will not.
|
* resource will be requested, otherwise it will not.
|
||||||
* @param accessibleScopes the scope(s) that allow access to the user info resource
|
* @param accessibleScopes the scope(s) that allow access to the user info resource
|
||||||
* @since 5.2
|
* @since 5.2
|
||||||
|
* @deprecated Use {@link #setRetrieveUserInfo(Predicate)} instead
|
||||||
*/
|
*/
|
||||||
|
@Deprecated(since = "6.3", forRemoval = true)
|
||||||
public final void setAccessibleScopes(Set<String> accessibleScopes) {
|
public final void setAccessibleScopes(Set<String> accessibleScopes) {
|
||||||
Assert.notNull(accessibleScopes, "accessibleScopes cannot be null");
|
Assert.notNull(accessibleScopes, "accessibleScopes cannot be null");
|
||||||
this.accessibleScopes = accessibleScopes;
|
this.accessibleScopes = accessibleScopes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
|
||||||
|
* called to retrieve information about the End-User (Resource Owner).
|
||||||
|
* <p>
|
||||||
|
* By default, the UserInfo Endpoint is called if all of the following are true:
|
||||||
|
* <ul>
|
||||||
|
* <li>The user info endpoint is defined on the ClientRegistration</li>
|
||||||
|
* <li>The Client Registration uses the
|
||||||
|
* {@link AuthorizationGrantType#AUTHORIZATION_CODE}</li>
|
||||||
|
* <li>The access token contains one or more scopes allowed to access the UserInfo
|
||||||
|
* Endpoint ({@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email},
|
||||||
|
* {@link OidcScopes#ADDRESS address} or {@link OidcScopes#PHONE phone}) or the access
|
||||||
|
* token scopes are empty</li>
|
||||||
|
* </ul>
|
||||||
|
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
|
||||||
|
* should be called
|
||||||
|
* @since 6.3
|
||||||
|
*/
|
||||||
|
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
|
||||||
|
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
|
||||||
|
this.retrieveUserInfo = retrieveUserInfo;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ import java.util.HashMap;
|
|||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
@ -107,6 +108,15 @@ public class OidcReactiveOAuth2UserServiceTests {
|
|||||||
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null));
|
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
// @formatter:off
|
||||||
|
assertThatIllegalArgumentException()
|
||||||
|
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
|
||||||
|
.withMessage("retrieveUserInfo cannot be null");
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() {
|
public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() {
|
||||||
this.registration.userInfoUri(null);
|
this.registration.userInfoUri(null);
|
||||||
@ -183,6 +193,48 @@ public class OidcReactiveOAuth2UserServiceTests {
|
|||||||
verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
|
verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void loadUserWhenTokenScopesIsEmptyThenUserInfoNotRetrieved() {
|
||||||
|
// @formatter:off
|
||||||
|
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
||||||
|
this.accessToken.getTokenType(),
|
||||||
|
this.accessToken.getTokenValue(),
|
||||||
|
this.accessToken.getIssuedAt(),
|
||||||
|
this.accessToken.getExpiresAt(),
|
||||||
|
Collections.emptySet());
|
||||||
|
// @formatter:on
|
||||||
|
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
|
||||||
|
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
|
||||||
|
assertThat(oidcUser).isNotNull();
|
||||||
|
assertThat(oidcUser.getUserInfo()).isNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
|
||||||
|
Map<String, Object> attributes = new HashMap<>();
|
||||||
|
attributes.put(StandardClaimNames.SUB, "subject");
|
||||||
|
attributes.put("user", "steve");
|
||||||
|
OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes,
|
||||||
|
"user");
|
||||||
|
given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User));
|
||||||
|
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
|
||||||
|
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
|
||||||
|
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
|
||||||
|
// @formatter:off
|
||||||
|
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
||||||
|
this.accessToken.getTokenType(),
|
||||||
|
this.accessToken.getTokenValue(),
|
||||||
|
this.accessToken.getIssuedAt(),
|
||||||
|
this.accessToken.getExpiresAt(),
|
||||||
|
Collections.emptySet());
|
||||||
|
// @formatter:on
|
||||||
|
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
|
||||||
|
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
|
||||||
|
assertThat(oidcUser).isNotNull();
|
||||||
|
assertThat(oidcUser.getUserInfo()).isNotNull();
|
||||||
|
verify(customRetrieveUserInfo).test(userRequest);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
|
public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
|
||||||
OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();
|
OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();
|
||||||
|
@ -23,6 +23,7 @@ import java.util.Iterator;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
@ -58,6 +59,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
|
|||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.same;
|
import static org.mockito.ArgumentMatchers.same;
|
||||||
import static org.mockito.BDDMockito.given;
|
import static org.mockito.BDDMockito.given;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
@ -129,6 +131,15 @@ public class OidcUserServiceTests {
|
|||||||
this.userService.setAccessibleScopes(Collections.emptySet());
|
this.userService.setAccessibleScopes(Collections.emptySet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
// @formatter:off
|
||||||
|
assertThatIllegalArgumentException()
|
||||||
|
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
|
||||||
|
.withMessage("retrieveUserInfo cannot be null");
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
|
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
|
||||||
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
|
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
|
||||||
@ -218,6 +229,30 @@ public class OidcUserServiceTests {
|
|||||||
assertThat(user.getUserInfo()).isNotNull();
|
assertThat(user.getUserInfo()).isNotNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
|
||||||
|
// @formatter:off
|
||||||
|
String userInfoResponse = "{\n"
|
||||||
|
+ " \"sub\": \"subject1\",\n"
|
||||||
|
+ " \"name\": \"first last\",\n"
|
||||||
|
+ " \"given_name\": \"first\",\n"
|
||||||
|
+ " \"family_name\": \"last\",\n"
|
||||||
|
+ " \"preferred_username\": \"user1\",\n"
|
||||||
|
+ " \"email\": \"user1@example.com\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||||
|
String userInfoUri = this.server.url("/user").toString();
|
||||||
|
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
|
||||||
|
this.accessToken = TestOAuth2AccessTokens.noScopes();
|
||||||
|
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
|
||||||
|
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
|
||||||
|
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
|
||||||
|
OidcUser user = this.userService
|
||||||
|
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
||||||
|
assertThat(user.getUserInfo()).isNotNull();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
|
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
|
||||||
// @formatter:off
|
// @formatter:off
|
||||||
|
Loading…
x
Reference in New Issue
Block a user