Add additional parameters to OAuth2UserRequest
Fixes gh-5368
This commit is contained in:
parent
950a314c9f
commit
8a0c6868cd
|
@ -30,6 +30,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
|
|||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login,
|
||||
|
@ -101,9 +102,10 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
|
|||
authorizationCodeAuthentication.getAuthorizationExchange()));
|
||||
|
||||
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
|
||||
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
|
||||
|
||||
OAuth2User oauth2User = this.userService.loadUser(
|
||||
new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken));
|
||||
OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
|
||||
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
|
||||
|
||||
Collection<? extends GrantedAuthority> mappedAuthorities =
|
||||
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
package org.springframework.security.oauth2.client.authentication;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.security.authentication.ReactiveAuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
|
@ -109,7 +110,9 @@ public class OAuth2LoginReactiveAuthenticationManager implements
|
|||
|
||||
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
|
||||
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
|
||||
OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken);
|
||||
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
|
||||
OAuth2UserRequest userRequest = new OAuth2UserRequest(
|
||||
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
|
||||
return this.userService.loadUser(userRequest)
|
||||
.flatMap(oauth2User -> {
|
||||
Collection<? extends GrantedAuthority> mappedAuthorities =
|
||||
|
|
|
@ -139,19 +139,18 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
|
|||
|
||||
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
|
||||
|
||||
if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
|
||||
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
|
||||
if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
|
||||
OAuth2Error invalidIdTokenError = new OAuth2Error(
|
||||
INVALID_ID_TOKEN_ERROR_CODE,
|
||||
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
|
||||
null);
|
||||
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
|
||||
}
|
||||
|
||||
OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);
|
||||
|
||||
OidcUser oidcUser = this.userService.loadUser(
|
||||
new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken));
|
||||
|
||||
OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(
|
||||
clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters));
|
||||
Collection<? extends GrantedAuthority> mappedAuthorities =
|
||||
this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities());
|
||||
|
||||
|
|
|
@ -159,10 +159,10 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
|
|||
|
||||
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
|
||||
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
|
||||
|
||||
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
|
||||
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
|
||||
|
||||
if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
|
||||
if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
|
||||
OAuth2Error invalidIdTokenError = new OAuth2Error(
|
||||
INVALID_ID_TOKEN_ERROR_CODE,
|
||||
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
|
||||
|
@ -171,7 +171,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
|
|||
}
|
||||
|
||||
return createOidcToken(clientRegistration, accessTokenResponse)
|
||||
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken))
|
||||
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
|
||||
.flatMap(this.userService::loadUser)
|
||||
.flatMap(oauth2User -> {
|
||||
Collection<? extends GrantedAuthority> mappedAuthorities =
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a request the {@link OidcUserService} uses
|
||||
* when initiating a request to the UserInfo Endpoint.
|
||||
|
@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest {
|
|||
public OidcUserRequest(ClientRegistration clientRegistration,
|
||||
OAuth2AccessToken accessToken, OidcIdToken idToken) {
|
||||
|
||||
super(clientRegistration, accessToken);
|
||||
this(clientRegistration, accessToken, idToken, Collections.emptyMap());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs an {@code OidcUserRequest} using the provided parameters.
|
||||
*
|
||||
* @since 5.1
|
||||
* @param clientRegistration the client registration
|
||||
* @param accessToken the access token credential
|
||||
* @param idToken the ID Token
|
||||
* @param additionalParameters the additional parameters, may be empty
|
||||
*/
|
||||
public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
|
||||
OidcIdToken idToken, Map<String, Object> additionalParameters) {
|
||||
|
||||
super(clientRegistration, accessToken, additionalParameters);
|
||||
Assert.notNull(idToken, "idToken cannot be null");
|
||||
this.idToken = idToken;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,11 @@ package org.springframework.security.oauth2.client.userinfo;
|
|||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a request the {@link OAuth2UserService} uses
|
||||
|
@ -32,6 +37,7 @@ import org.springframework.util.Assert;
|
|||
public class OAuth2UserRequest {
|
||||
private final ClientRegistration clientRegistration;
|
||||
private final OAuth2AccessToken accessToken;
|
||||
private final Map<String, Object> additionalParameters;
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2UserRequest} using the provided parameters.
|
||||
|
@ -40,10 +46,26 @@ public class OAuth2UserRequest {
|
|||
* @param accessToken the access token
|
||||
*/
|
||||
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) {
|
||||
this(clientRegistration, accessToken, Collections.emptyMap());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2UserRequest} using the provided parameters.
|
||||
*
|
||||
* @since 5.1
|
||||
* @param clientRegistration the client registration
|
||||
* @param accessToken the access token
|
||||
* @param additionalParameters the additional parameters, may be empty
|
||||
*/
|
||||
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
|
||||
Map<String, Object> additionalParameters) {
|
||||
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
|
||||
Assert.notNull(accessToken, "accessToken cannot be null");
|
||||
this.clientRegistration = clientRegistration;
|
||||
this.accessToken = accessToken;
|
||||
this.additionalParameters = Collections.unmodifiableMap(
|
||||
CollectionUtils.isEmpty(additionalParameters) ?
|
||||
Collections.emptyMap() : new LinkedHashMap<>(additionalParameters));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -63,4 +85,14 @@ public class OAuth2UserRequest {
|
|||
public OAuth2AccessToken getAccessToken() {
|
||||
return this.accessToken;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the additional parameters that may be used in the request.
|
||||
*
|
||||
* @since 5.1
|
||||
* @return a {@code Map} of the additional parameters, may be empty.
|
||||
*/
|
||||
public Map<String, Object> getAdditionalParameters() {
|
||||
return this.additionalParameters;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.junit.Rule;
|
|||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
import org.powermock.modules.junit4.PowerMockRunner;
|
||||
|
@ -35,17 +36,20 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
|
@ -164,11 +168,7 @@ public class OAuth2LoginAuthenticationProviderTests {
|
|||
|
||||
@Test
|
||||
public void authenticateWhenLoginSuccessThenReturnAuthentication() {
|
||||
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
|
||||
OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
|
||||
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
|
||||
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
|
||||
when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
|
||||
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
||||
|
||||
OAuth2User principal = mock(OAuth2User.class);
|
||||
|
@ -187,15 +187,13 @@ public class OAuth2LoginAuthenticationProviderTests {
|
|||
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
|
||||
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
|
||||
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
|
||||
assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
|
||||
assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken);
|
||||
assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
|
||||
assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
|
||||
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
|
||||
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
|
||||
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
|
||||
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
||||
|
||||
OAuth2User principal = mock(OAuth2User.class);
|
||||
|
@ -216,4 +214,42 @@ public class OAuth2LoginAuthenticationProviderTests {
|
|||
|
||||
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
|
||||
}
|
||||
|
||||
// gh-5368
|
||||
@Test
|
||||
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
|
||||
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
||||
|
||||
OAuth2User principal = mock(OAuth2User.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
when(principal.getAuthorities()).thenAnswer(
|
||||
(Answer<List<GrantedAuthority>>) invocation -> authorities);
|
||||
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
|
||||
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
|
||||
|
||||
this.authenticationProvider.authenticate(
|
||||
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
|
||||
|
||||
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
|
||||
accessTokenResponse.getAdditionalParameters());
|
||||
}
|
||||
|
||||
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
|
||||
Instant expiresAt = Instant.now().plusSeconds(5);
|
||||
Set<String> scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put("param1", "value1");
|
||||
additionalParameters.put("param2", "value2");
|
||||
|
||||
return OAuth2AccessTokenResponse
|
||||
.withToken("access-token-1234")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(expiresAt.getEpochSecond())
|
||||
.scopes(scopes)
|
||||
.refreshToken("refresh-token-1234")
|
||||
.additionalParameters(additionalParameters)
|
||||
.build();
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,11 +23,14 @@ import static org.mockito.ArgumentMatchers.any;
|
|||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
|
@ -164,7 +167,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void authenticationWhenOAuth2UserNotFoundThenSuccess() {
|
||||
public void authenticationWhenOAuth2UserFoundThenSuccess() {
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.build();
|
||||
|
@ -179,6 +182,27 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
|
|||
assertThat(result.isAuthenticated()).isTrue();
|
||||
}
|
||||
|
||||
// gh-5368
|
||||
@Test
|
||||
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put("param1", "value1");
|
||||
additionalParameters.put("param2", "value2");
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.additionalParameters(additionalParameters)
|
||||
.build();
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
|
||||
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
|
||||
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
|
||||
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
|
||||
|
||||
this.manager.authenticate(loginToken()).block();
|
||||
|
||||
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
|
||||
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
|
||||
}
|
||||
|
||||
private OAuth2LoginAuthenticationToken loginToken() {
|
||||
ClientRegistration clientRegistration = this.registration.build();
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.junit.Rule;
|
|||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.powermock.api.mockito.PowerMockito;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
|
@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
|
@ -55,6 +55,7 @@ import java.util.HashMap;
|
|||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
|
@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
private OAuth2AuthorizationExchange authorizationExchange;
|
||||
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
|
||||
private OAuth2AccessTokenResponse accessTokenResponse;
|
||||
private OAuth2AccessToken accessToken;
|
||||
private OAuth2RefreshToken refreshToken;
|
||||
private OAuth2UserService<OidcUserRequest, OidcUser> userService;
|
||||
private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
|
||||
|
||||
|
@ -95,9 +94,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
|
||||
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
|
||||
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
|
||||
this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
|
||||
this.accessToken = mock(OAuth2AccessToken.class);
|
||||
this.refreshToken = mock(OAuth2RefreshToken.class);
|
||||
this.accessTokenResponse = this.accessTokenSuccessResponse();
|
||||
this.userService = mock(OAuth2UserService.class);
|
||||
this.authenticationProvider = PowerMockito.spy(
|
||||
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
|
||||
|
@ -111,11 +108,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
when(this.authorizationResponse.getState()).thenReturn("12345");
|
||||
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
|
||||
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
|
||||
when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
|
||||
when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken);
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
|
||||
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
|
||||
}
|
||||
|
||||
|
@ -194,7 +186,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_id_token"));
|
||||
|
||||
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
|
||||
OAuth2AccessTokenResponse accessTokenResponse =
|
||||
OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse())
|
||||
.additionalParameters(Collections.emptyMap())
|
||||
.build();
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
||||
|
||||
this.authenticationProvider.authenticate(
|
||||
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
|
||||
|
@ -368,8 +364,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
|
||||
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
|
||||
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
|
||||
assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
|
||||
assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken);
|
||||
assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken());
|
||||
assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -400,6 +396,30 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
|
||||
}
|
||||
|
||||
// gh-5368
|
||||
@Test
|
||||
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
|
||||
Map<String, Object> claims = new HashMap<>();
|
||||
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
|
||||
claims.put(IdTokenClaimNames.SUB, "subject1");
|
||||
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
|
||||
claims.put(IdTokenClaimNames.AZP, "client1");
|
||||
this.setUpIdToken(claims);
|
||||
|
||||
OidcUser principal = mock(OidcUser.class);
|
||||
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
||||
when(principal.getAuthorities()).thenAnswer(
|
||||
(Answer<List<GrantedAuthority>>) invocation -> authorities);
|
||||
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
|
||||
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
|
||||
|
||||
this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken(
|
||||
this.clientRegistration, this.authorizationExchange));
|
||||
|
||||
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
|
||||
this.accessTokenResponse.getAdditionalParameters());
|
||||
}
|
||||
|
||||
private void setUpIdToken(Map<String, Object> claims) throws Exception {
|
||||
Instant issuedAt = Instant.now();
|
||||
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
|
||||
|
@ -416,4 +436,23 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|||
when(jwtDecoder.decode(anyString())).thenReturn(idToken);
|
||||
PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
|
||||
}
|
||||
|
||||
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
|
||||
Instant expiresAt = Instant.now().plusSeconds(5);
|
||||
Set<String> scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put("param1", "value1");
|
||||
additionalParameters.put("param2", "value2");
|
||||
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
|
||||
|
||||
return OAuth2AccessTokenResponse
|
||||
.withToken("access-token-1234")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(expiresAt.getEpochSecond())
|
||||
.scopes(scopes)
|
||||
.refreshToken("refresh-token-1234")
|
||||
.additionalParameters(additionalParameters)
|
||||
.build();
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.authentication;
|
|||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
|
@ -217,6 +218,39 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
|
|||
assertThat(result.isAuthenticated()).isTrue();
|
||||
}
|
||||
|
||||
// gh-5368
|
||||
@Test
|
||||
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
|
||||
additionalParameters.put("param1", "value1");
|
||||
additionalParameters.put("param2", "value2");
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.additionalParameters(additionalParameters)
|
||||
.build();
|
||||
|
||||
Map<String, Object> claims = new HashMap<>();
|
||||
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
|
||||
claims.put(IdTokenClaimNames.SUB, "rob");
|
||||
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId"));
|
||||
Instant issuedAt = Instant.now();
|
||||
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
|
||||
Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims);
|
||||
|
||||
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
|
||||
DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
|
||||
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
|
||||
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
|
||||
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
|
||||
this.manager.setDecoderFactory(c -> this.jwtDecoder);
|
||||
|
||||
this.manager.authenticate(loginToken()).block();
|
||||
|
||||
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
|
||||
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
|
||||
}
|
||||
|
||||
private OAuth2LoginAuthenticationToken loginToken() {
|
||||
ClientRegistration clientRegistration = this.registration.build();
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,57 +17,87 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
|
|||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
import org.powermock.modules.junit4.PowerMockRunner;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
|
||||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Tests for {@link OidcUserRequest}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
*/
|
||||
@RunWith(PowerMockRunner.class)
|
||||
@PrepareForTest(ClientRegistration.class)
|
||||
public class OidcUserRequestTests {
|
||||
private ClientRegistration clientRegistration;
|
||||
private OAuth2AccessToken accessToken;
|
||||
private OidcIdToken idToken;
|
||||
private Map<String, Object> additionalParameters;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
this.clientRegistration = mock(ClientRegistration.class);
|
||||
this.accessToken = mock(OAuth2AccessToken.class);
|
||||
this.idToken = mock(OidcIdToken.class);
|
||||
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
|
||||
.clientId("client-1")
|
||||
.clientSecret("secret")
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.redirectUriTemplate("https://client.com")
|
||||
.scope(new LinkedHashSet<>(Arrays.asList("openid", "profile")))
|
||||
.authorizationUri("https://provider.com/oauth2/authorization")
|
||||
.tokenUri("https://provider.com/oauth2/token")
|
||||
.jwkSetUri("https://provider.com/keys")
|
||||
.clientName("Client 1")
|
||||
.build();
|
||||
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
|
||||
"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
|
||||
new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
|
||||
Map<String, Object> claims = new HashMap<>();
|
||||
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
|
||||
claims.put(IdTokenClaimNames.SUB, "subject1");
|
||||
claims.put(IdTokenClaimNames.AZP, "client-1");
|
||||
this.idToken = new OidcIdToken("id-token-1234", Instant.now(),
|
||||
Instant.now().plusSeconds(3600), claims);
|
||||
this.additionalParameters = new HashMap<>();
|
||||
this.additionalParameters.put("param1", "value1");
|
||||
this.additionalParameters.put("param2", "value2");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test
|
||||
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
|
||||
new OidcUserRequest(null, this.accessToken, this.idToken);
|
||||
assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test
|
||||
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
|
||||
new OidcUserRequest(this.clientRegistration, null, this.idToken);
|
||||
assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test
|
||||
public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
|
||||
new OidcUserRequest(this.clientRegistration, this.accessToken, null);
|
||||
assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
|
||||
OidcUserRequest userRequest = new OidcUserRequest(
|
||||
this.clientRegistration, this.accessToken, this.idToken);
|
||||
this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters);
|
||||
|
||||
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
|
||||
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
|
||||
assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
|
||||
assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,47 +17,70 @@ package org.springframework.security.oauth2.client.userinfo;
|
|||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.powermock.core.classloader.annotations.PrepareForTest;
|
||||
import org.powermock.modules.junit4.PowerMockRunner;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2UserRequest}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
*/
|
||||
@RunWith(PowerMockRunner.class)
|
||||
@PrepareForTest(ClientRegistration.class)
|
||||
public class OAuth2UserRequestTests {
|
||||
private ClientRegistration clientRegistration;
|
||||
private OAuth2AccessToken accessToken;
|
||||
private Map<String, Object> additionalParameters;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
this.clientRegistration = mock(ClientRegistration.class);
|
||||
this.accessToken = mock(OAuth2AccessToken.class);
|
||||
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
|
||||
.clientId("client-1")
|
||||
.clientSecret("secret")
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.redirectUriTemplate("https://client.com")
|
||||
.scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")))
|
||||
.authorizationUri("https://provider.com/oauth2/authorization")
|
||||
.tokenUri("https://provider.com/oauth2/token")
|
||||
.clientName("Client 1")
|
||||
.build();
|
||||
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
|
||||
"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
|
||||
new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
|
||||
this.additionalParameters = new HashMap<>();
|
||||
this.additionalParameters.put("param1", "value1");
|
||||
this.additionalParameters.put("param2", "value2");
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test
|
||||
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
|
||||
new OAuth2UserRequest(null, this.accessToken);
|
||||
assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test
|
||||
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
|
||||
new OAuth2UserRequest(this.clientRegistration, null);
|
||||
assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
|
||||
OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken);
|
||||
OAuth2UserRequest userRequest = new OAuth2UserRequest(
|
||||
this.clientRegistration, this.accessToken, this.additionalParameters);
|
||||
|
||||
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
|
||||
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
|
||||
assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue