diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
index 938fd9c215..fc5fb598b3 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -73,7 +73,8 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(),
accessTokenResponse.getAccessToken(),
- accessTokenResponse.getRefreshToken());
+ accessTokenResponse.getRefreshToken(),
+ accessTokenResponse.getAdditionalParameters());
authenticationResult.setDetails(authorizationCodeAuthentication.getDetails());
return authenticationResult;
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java
index 02d4b5e7b1..8d3b70234d 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert;
@@ -60,7 +59,7 @@ import java.util.Map;
* @see Section 4.1.4 Access Token Response
*/
public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider {
- private final OAuth2AccessTokenResponseClient accessTokenResponseClient;
+ private final OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider;
private final OAuth2UserService userService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
@@ -74,59 +73,54 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
OAuth2AccessTokenResponseClient accessTokenResponseClient,
OAuth2UserService userService) {
- Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
Assert.notNull(userService, "userService cannot be null");
- this.accessTokenResponseClient = accessTokenResponseClient;
+ this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient);
this.userService = userService;
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
- OAuth2LoginAuthenticationToken authorizationCodeAuthentication =
+ OAuth2LoginAuthenticationToken loginAuthenticationToken =
(OAuth2LoginAuthenticationToken) authentication;
// Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// scope
// REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
- if (authorizationCodeAuthentication.getAuthorizationExchange()
+ if (loginAuthenticationToken.getAuthorizationExchange()
.getAuthorizationRequest().getScopes().contains("openid")) {
// This is an OpenID Connect Authentication Request so return null
// and let OidcAuthorizationCodeAuthenticationProvider handle it instead
return null;
}
- OAuth2AccessTokenResponse accessTokenResponse;
+ OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken;
try {
- OAuth2AuthorizationExchangeValidator.validate(
- authorizationCodeAuthentication.getAuthorizationExchange());
-
- accessTokenResponse = this.accessTokenResponseClient.getTokenResponse(
- new OAuth2AuthorizationCodeGrantRequest(
- authorizationCodeAuthentication.getClientRegistration(),
- authorizationCodeAuthentication.getAuthorizationExchange()));
-
+ authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider
+ .authenticate(new OAuth2AuthorizationCodeAuthenticationToken(
+ loginAuthenticationToken.getClientRegistration(),
+ loginAuthenticationToken.getAuthorizationExchange()));
} catch (OAuth2AuthorizationException ex) {
OAuth2Error oauth2Error = ex.getError();
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
- OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
- Map additionalParameters = accessTokenResponse.getAdditionalParameters();
+ OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken();
+ Map additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters();
OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
- authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
+ loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters));
Collection extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken(
- authorizationCodeAuthentication.getClientRegistration(),
- authorizationCodeAuthentication.getAuthorizationExchange(),
+ loginAuthenticationToken.getClientRegistration(),
+ loginAuthenticationToken.getAuthorizationExchange(),
oauth2User,
mappedAuthorities,
accessToken,
- accessTokenResponse.getRefreshToken());
- authenticationResult.setDetails(authorizationCodeAuthentication.getDetails());
+ authorizationCodeAuthenticationToken.getRefreshToken());
+ authenticationResult.setDetails(loginAuthenticationToken.getDetails());
return authenticationResult;
}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
index 2ba5e36927..41ebe4a1e6 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
package org.springframework.security.oauth2.client.authentication;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Before;
import org.junit.Test;
@@ -119,4 +121,26 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}
+
+ // gh-5368
+ @Test
+ public void authenticateWhenAuthorizationSuccessResponseThenAdditionalParametersIncluded() {
+ Map additionalParameters = new HashMap<>();
+ additionalParameters.put("param1", "value1");
+ additionalParameters.put("param2", "value2");
+
+ OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters)
+ .build();
+ when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+ OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
+ success().build());
+
+ OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider
+ .authenticate(
+ new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange));
+
+ assertThat(authentication.getAdditionalParameters())
+ .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
+ }
}