Add additional parameters to OAuth2UserRequest

Fixes gh-5368
This commit is contained in:
Joe Grandja 2018-08-13 07:51:06 -04:00
parent 950a314c9f
commit 8a0c6868cd
12 changed files with 311 additions and 71 deletions

View File

@ -30,6 +30,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
/** /**
* An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login, * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login,
@ -101,9 +102,10 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
authorizationCodeAuthentication.getAuthorizationExchange())); authorizationCodeAuthentication.getAuthorizationExchange()));
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
OAuth2User oauth2User = this.userService.loadUser( OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken)); authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.authentication; package org.springframework.security.oauth2.client.authentication;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -109,7 +110,9 @@ public class OAuth2LoginReactiveAuthenticationManager implements
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken); Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
OAuth2UserRequest userRequest = new OAuth2UserRequest(
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
return this.userService.loadUser(userRequest) return this.userService.loadUser(userRequest)
.flatMap(oauth2User -> { .flatMap(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =

View File

@ -139,19 +139,18 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) { Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
OAuth2Error invalidIdTokenError = new OAuth2Error( OAuth2Error invalidIdTokenError = new OAuth2Error(
INVALID_ID_TOKEN_ERROR_CODE, INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null); null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
} }
OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);
OidcUser oidcUser = this.userService.loadUser( OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(
new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken)); clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters));
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities()); this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities());

View File

@ -159,10 +159,10 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) { if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
OAuth2Error invalidIdTokenError = new OAuth2Error( OAuth2Error invalidIdTokenError = new OAuth2Error(
INVALID_ID_TOKEN_ERROR_CODE, INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
@ -171,7 +171,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
} }
return createOidcToken(clientRegistration, accessTokenResponse) return createOidcToken(clientRegistration, accessTokenResponse)
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken)) .map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
.flatMap(this.userService::loadUser) .flatMap(this.userService::loadUser)
.flatMap(oauth2User -> { .flatMap(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,6 +21,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.Collections;
import java.util.Map;
/** /**
* Represents a request the {@link OidcUserService} uses * Represents a request the {@link OidcUserService} uses
* when initiating a request to the UserInfo Endpoint. * when initiating a request to the UserInfo Endpoint.
@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest {
public OidcUserRequest(ClientRegistration clientRegistration, public OidcUserRequest(ClientRegistration clientRegistration,
OAuth2AccessToken accessToken, OidcIdToken idToken) { OAuth2AccessToken accessToken, OidcIdToken idToken) {
super(clientRegistration, accessToken); this(clientRegistration, accessToken, idToken, Collections.emptyMap());
}
/**
* Constructs an {@code OidcUserRequest} using the provided parameters.
*
* @since 5.1
* @param clientRegistration the client registration
* @param accessToken the access token credential
* @param idToken the ID Token
* @param additionalParameters the additional parameters, may be empty
*/
public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
OidcIdToken idToken, Map<String, Object> additionalParameters) {
super(clientRegistration, accessToken, additionalParameters);
Assert.notNull(idToken, "idToken cannot be null"); Assert.notNull(idToken, "idToken cannot be null");
this.idToken = idToken; this.idToken = idToken;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,11 @@ package org.springframework.security.oauth2.client.userinfo;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
/** /**
* Represents a request the {@link OAuth2UserService} uses * Represents a request the {@link OAuth2UserService} uses
@ -32,6 +37,7 @@ import org.springframework.util.Assert;
public class OAuth2UserRequest { public class OAuth2UserRequest {
private final ClientRegistration clientRegistration; private final ClientRegistration clientRegistration;
private final OAuth2AccessToken accessToken; private final OAuth2AccessToken accessToken;
private final Map<String, Object> additionalParameters;
/** /**
* Constructs an {@code OAuth2UserRequest} using the provided parameters. * Constructs an {@code OAuth2UserRequest} using the provided parameters.
@ -40,10 +46,26 @@ public class OAuth2UserRequest {
* @param accessToken the access token * @param accessToken the access token
*/ */
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) { public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) {
this(clientRegistration, accessToken, Collections.emptyMap());
}
/**
* Constructs an {@code OAuth2UserRequest} using the provided parameters.
*
* @since 5.1
* @param clientRegistration the client registration
* @param accessToken the access token
* @param additionalParameters the additional parameters, may be empty
*/
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
Map<String, Object> additionalParameters) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(clientRegistration, "clientRegistration cannot be null");
Assert.notNull(accessToken, "accessToken cannot be null"); Assert.notNull(accessToken, "accessToken cannot be null");
this.clientRegistration = clientRegistration; this.clientRegistration = clientRegistration;
this.accessToken = accessToken; this.accessToken = accessToken;
this.additionalParameters = Collections.unmodifiableMap(
CollectionUtils.isEmpty(additionalParameters) ?
Collections.emptyMap() : new LinkedHashMap<>(additionalParameters));
} }
/** /**
@ -63,4 +85,14 @@ public class OAuth2UserRequest {
public OAuth2AccessToken getAccessToken() { public OAuth2AccessToken getAccessToken() {
return this.accessToken; return this.accessToken;
} }
/**
* Returns the additional parameters that may be used in the request.
*
* @since 5.1
* @return a {@code Map} of the additional parameters, may be empty.
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
} }

View File

@ -20,6 +20,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
@ -35,17 +36,20 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
@ -164,11 +168,7 @@ public class OAuth2LoginAuthenticationProviderTests {
@Test @Test
public void authenticateWhenLoginSuccessThenReturnAuthentication() { public void authenticateWhenLoginSuccessThenReturnAuthentication() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2User principal = mock(OAuth2User.class); OAuth2User principal = mock(OAuth2User.class);
@ -187,15 +187,13 @@ public class OAuth2LoginAuthenticationProviderTests {
assertThat(authentication.getAuthorities()).isEqualTo(authorities); assertThat(authentication.getAuthorities()).isEqualTo(authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(accessToken); assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken); assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
} }
@Test @Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2User principal = mock(OAuth2User.class); OAuth2User principal = mock(OAuth2User.class);
@ -216,4 +214,42 @@ public class OAuth2LoginAuthenticationProviderTests {
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
} }
// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2User principal = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
accessTokenResponse.getAdditionalParameters());
}
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
Instant expiresAt = Instant.now().plusSeconds(5);
Set<String> scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
return OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(expiresAt.getEpochSecond())
.scopes(scopes)
.refreshToken("refresh-token-1234")
.additionalParameters(additionalParameters)
.build();
}
} }

View File

@ -23,11 +23,14 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
@ -164,7 +167,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
} }
@Test @Test
public void authenticationWhenOAuth2UserNotFoundThenSuccess() { public void authenticationWhenOAuth2UserFoundThenSuccess() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.build(); .build();
@ -179,6 +182,27 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
assertThat(result.isAuthenticated()).isTrue(); assertThat(result.isAuthenticated()).isTrue();
} }
// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(additionalParameters)
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
this.manager.authenticate(loginToken()).block();
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}
private OAuth2LoginAuthenticationToken loginToken() { private OAuth2LoginAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest

View File

@ -20,6 +20,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito; import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@ -55,6 +55,7 @@ import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
private OAuth2AuthorizationExchange authorizationExchange; private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient; private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2AccessTokenResponse accessTokenResponse; private OAuth2AccessTokenResponse accessTokenResponse;
private OAuth2AccessToken accessToken;
private OAuth2RefreshToken refreshToken;
private OAuth2UserService<OidcUserRequest, OidcUser> userService; private OAuth2UserService<OidcUserRequest, OidcUser> userService;
private OidcAuthorizationCodeAuthenticationProvider authenticationProvider; private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
@ -95,9 +94,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class); this.accessTokenResponse = this.accessTokenSuccessResponse();
this.accessToken = mock(OAuth2AccessToken.class);
this.refreshToken = mock(OAuth2RefreshToken.class);
this.userService = mock(OAuth2UserService.class); this.userService = mock(OAuth2UserService.class);
this.authenticationProvider = PowerMockito.spy( this.authenticationProvider = PowerMockito.spy(
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService)); new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
@ -111,11 +108,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
when(this.authorizationResponse.getState()).thenReturn("12345"); when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
} }
@ -194,7 +186,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token")); this.exception.expectMessage(containsString("invalid_id_token"));
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap()); OAuth2AccessTokenResponse accessTokenResponse =
OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse())
.additionalParameters(Collections.emptyMap())
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
this.authenticationProvider.authenticate( this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
@ -368,8 +364,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
assertThat(authentication.getAuthorities()).isEqualTo(authorities); assertThat(authentication.getAuthorities()).isEqualTo(authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken());
assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken); assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken());
} }
@Test @Test
@ -400,6 +396,30 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
} }
// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "client1");
this.setUpIdToken(claims);
OidcUser principal = mock(OidcUser.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken(
this.clientRegistration, this.authorizationExchange));
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
this.accessTokenResponse.getAdditionalParameters());
}
private void setUpIdToken(Map<String, Object> claims) throws Exception { private void setUpIdToken(Map<String, Object> claims) throws Exception {
Instant issuedAt = Instant.now(); Instant issuedAt = Instant.now();
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
@ -416,4 +436,23 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
when(jwtDecoder.decode(anyString())).thenReturn(idToken); when(jwtDecoder.decode(anyString())).thenReturn(idToken);
PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class)); PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
} }
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
Instant expiresAt = Instant.now().plusSeconds(5);
Set<String> scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
return OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(expiresAt.getEpochSecond())
.scopes(scopes)
.refreshToken("refresh-token-1234")
.additionalParameters(additionalParameters)
.build();
}
} }

View File

@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.authentication;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
@ -217,6 +218,39 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
assertThat(result.isAuthenticated()).isTrue(); assertThat(result.isAuthenticated()).isTrue();
} }
// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(additionalParameters)
.build();
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
claims.put(IdTokenClaimNames.SUB, "rob");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId"));
Instant issuedAt = Instant.now();
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder);
this.manager.authenticate(loginToken()).block();
assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}
private OAuth2LoginAuthenticationToken loginToken() { private OAuth2LoginAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,57 +17,87 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock; import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** /**
* Tests for {@link OidcUserRequest}. * Tests for {@link OidcUserRequest}.
* *
* @author Joe Grandja * @author Joe Grandja
*/ */
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OidcUserRequestTests { public class OidcUserRequestTests {
private ClientRegistration clientRegistration; private ClientRegistration clientRegistration;
private OAuth2AccessToken accessToken; private OAuth2AccessToken accessToken;
private OidcIdToken idToken; private OidcIdToken idToken;
private Map<String, Object> additionalParameters;
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistration = mock(ClientRegistration.class); this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
this.accessToken = mock(OAuth2AccessToken.class); .clientId("client-1")
this.idToken = mock(OidcIdToken.class); .clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("https://client.com")
.scope(new LinkedHashSet<>(Arrays.asList("openid", "profile")))
.authorizationUri("https://provider.com/oauth2/authorization")
.tokenUri("https://provider.com/oauth2/token")
.jwkSetUri("https://provider.com/keys")
.clientName("Client 1")
.build();
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AZP, "client-1");
this.idToken = new OidcIdToken("id-token-1234", Instant.now(),
Instant.now().plusSeconds(3600), claims);
this.additionalParameters = new HashMap<>();
this.additionalParameters.put("param1", "value1");
this.additionalParameters.put("param2", "value2");
} }
@Test(expected = IllegalArgumentException.class) @Test
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(null, this.accessToken, this.idToken); assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken))
.isInstanceOf(IllegalArgumentException.class);
} }
@Test(expected = IllegalArgumentException.class) @Test
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(this.clientRegistration, null, this.idToken); assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken))
.isInstanceOf(IllegalArgumentException.class);
} }
@Test(expected = IllegalArgumentException.class) @Test
public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() { public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(this.clientRegistration, this.accessToken, null); assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null))
.isInstanceOf(IllegalArgumentException.class);
} }
@Test @Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() { public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OidcUserRequest userRequest = new OidcUserRequest( OidcUserRequest userRequest = new OidcUserRequest(
this.clientRegistration, this.accessToken, this.idToken); this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters);
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
assertThat(userRequest.getIdToken()).isEqualTo(this.idToken); assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,47 +17,70 @@ package org.springframework.security.oauth2.client.userinfo;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock; import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** /**
* Tests for {@link OAuth2UserRequest}. * Tests for {@link OAuth2UserRequest}.
* *
* @author Joe Grandja * @author Joe Grandja
*/ */
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OAuth2UserRequestTests { public class OAuth2UserRequestTests {
private ClientRegistration clientRegistration; private ClientRegistration clientRegistration;
private OAuth2AccessToken accessToken; private OAuth2AccessToken accessToken;
private Map<String, Object> additionalParameters;
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistration = mock(ClientRegistration.class); this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
this.accessToken = mock(OAuth2AccessToken.class); .clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("https://client.com")
.scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")))
.authorizationUri("https://provider.com/oauth2/authorization")
.tokenUri("https://provider.com/oauth2/token")
.clientName("Client 1")
.build();
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
this.additionalParameters = new HashMap<>();
this.additionalParameters.put("param1", "value1");
this.additionalParameters.put("param2", "value2");
} }
@Test(expected = IllegalArgumentException.class) @Test
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2UserRequest(null, this.accessToken); assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken))
.isInstanceOf(IllegalArgumentException.class);
} }
@Test(expected = IllegalArgumentException.class) @Test
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OAuth2UserRequest(this.clientRegistration, null); assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null))
.isInstanceOf(IllegalArgumentException.class);
} }
@Test @Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() { public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken); OAuth2UserRequest userRequest = new OAuth2UserRequest(
this.clientRegistration, this.accessToken, this.additionalParameters);
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
} }
} }