diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 938fd9c215..fc5fb598b3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,8 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange(), accessTokenResponse.getAccessToken(), - accessTokenResponse.getRefreshToken()); + accessTokenResponse.getRefreshToken(), + accessTokenResponse.getAdditionalParameters()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); return authenticationResult; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java index 02d4b5e7b1..8d3b70234d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; @@ -60,7 +59,7 @@ import java.util.Map; * @see Section 4.1.4 Access Token Response */ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider { - private final OAuth2AccessTokenResponseClient accessTokenResponseClient; + private final OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider; private final OAuth2UserService userService; private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); @@ -74,59 +73,54 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider OAuth2AccessTokenResponseClient accessTokenResponseClient, OAuth2UserService userService) { - Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); - this.accessTokenResponseClient = accessTokenResponseClient; + this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient); this.userService = userService; } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2LoginAuthenticationToken authorizationCodeAuthentication = + OAuth2LoginAuthenticationToken loginAuthenticationToken = (OAuth2LoginAuthenticationToken) authentication; // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // scope // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (authorizationCodeAuthentication.getAuthorizationExchange() + if (loginAuthenticationToken.getAuthorizationExchange() .getAuthorizationRequest().getScopes().contains("openid")) { // This is an OpenID Connect Authentication Request so return null // and let OidcAuthorizationCodeAuthenticationProvider handle it instead return null; } - OAuth2AccessTokenResponse accessTokenResponse; + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken; try { - OAuth2AuthorizationExchangeValidator.validate( - authorizationCodeAuthentication.getAuthorizationExchange()); - - accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange())); - + authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider + .authenticate(new OAuth2AuthorizationCodeAuthenticationToken( + loginAuthenticationToken.getClientRegistration(), + loginAuthenticationToken.getAuthorizationExchange())); } catch (OAuth2AuthorizationException ex) { OAuth2Error oauth2Error = ex.getError(); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); - Map additionalParameters = accessTokenResponse.getAdditionalParameters(); + OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken(); + Map additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters(); OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest( - authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters)); + loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters)); Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange(), + loginAuthenticationToken.getClientRegistration(), + loginAuthenticationToken.getAuthorizationExchange(), oauth2User, mappedAuthorities, accessToken, - accessTokenResponse.getRefreshToken()); - authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); + authorizationCodeAuthenticationToken.getRefreshToken()); + authenticationResult.setDetails(loginAuthenticationToken.getDetails()); return authenticationResult; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 2ba5e36927..41ebe4a1e6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.client.authentication; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -119,4 +121,26 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } + + // gh-5368 + @Test + public void authenticateWhenAuthorizationSuccessResponseThenAdditionalParametersIncluded() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + + OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters) + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + success().build()); + + OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider + .authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); + + assertThat(authentication.getAdditionalParameters()) + .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); + } }