Improve OAuth2LoginAuthenticationProvider

1. update OAuth2LoginAuthenticationProvider to use
OAuth2AuthorizationCodeAuthenticationProvider
2. apply fix gh-5368 for OAuth2AuthorizationCodeAuthenticationProvider
to return additionalParameters value from accessTokenResponse

Fixes gh-5633
This commit is contained in:
Ruby Hartono 2020-03-19 19:54:12 +07:00 committed by Joe Grandja
parent 4c040e9e8e
commit 45eb34c9a6
3 changed files with 45 additions and 26 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -73,7 +73,8 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(), authorizationCodeAuthentication.getAuthorizationExchange(),
accessTokenResponse.getAccessToken(), accessTokenResponse.getAccessToken(),
accessTokenResponse.getRefreshToken()); accessTokenResponse.getRefreshToken(),
accessTokenResponse.getAdditionalParameters());
authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails());
return authenticationResult; return authenticationResult;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -60,7 +59,7 @@ import java.util.Map;
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
*/ */
public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider { public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider {
private final OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient; private final OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider;
private final OAuth2UserService<OAuth2UserRequest, OAuth2User> userService; private final OAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
@ -74,59 +73,54 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient, OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient,
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService) { OAuth2UserService<OAuth2UserRequest, OAuth2User> userService) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
Assert.notNull(userService, "userService cannot be null"); Assert.notNull(userService, "userService cannot be null");
this.accessTokenResponseClient = accessTokenResponseClient; this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient);
this.userService = userService; this.userService = userService;
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2LoginAuthenticationToken authorizationCodeAuthentication = OAuth2LoginAuthenticationToken loginAuthenticationToken =
(OAuth2LoginAuthenticationToken) authentication; (OAuth2LoginAuthenticationToken) authentication;
// Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// scope // scope
// REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
if (authorizationCodeAuthentication.getAuthorizationExchange() if (loginAuthenticationToken.getAuthorizationExchange()
.getAuthorizationRequest().getScopes().contains("openid")) { .getAuthorizationRequest().getScopes().contains("openid")) {
// This is an OpenID Connect Authentication Request so return null // This is an OpenID Connect Authentication Request so return null
// and let OidcAuthorizationCodeAuthenticationProvider handle it instead // and let OidcAuthorizationCodeAuthenticationProvider handle it instead
return null; return null;
} }
OAuth2AccessTokenResponse accessTokenResponse; OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken;
try { try {
OAuth2AuthorizationExchangeValidator.validate( authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider
authorizationCodeAuthentication.getAuthorizationExchange()); .authenticate(new OAuth2AuthorizationCodeAuthenticationToken(
loginAuthenticationToken.getClientRegistration(),
accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( loginAuthenticationToken.getAuthorizationExchange()));
new OAuth2AuthorizationCodeGrantRequest(
authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange()));
} catch (OAuth2AuthorizationException ex) { } catch (OAuth2AuthorizationException ex) {
OAuth2Error oauth2Error = ex.getError(); OAuth2Error oauth2Error = ex.getError();
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters(); Map<String, Object> additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters();
OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest( OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters)); loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters));
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken(
authorizationCodeAuthentication.getClientRegistration(), loginAuthenticationToken.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(), loginAuthenticationToken.getAuthorizationExchange(),
oauth2User, oauth2User,
mappedAuthorities, mappedAuthorities,
accessToken, accessToken,
accessTokenResponse.getRefreshToken()); authorizationCodeAuthenticationToken.getRefreshToken());
authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); authenticationResult.setDetails(loginAuthenticationToken.getDetails());
return authenticationResult; return authenticationResult;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.security.oauth2.client.authentication; package org.springframework.security.oauth2.client.authentication;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -119,4 +121,26 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
} }
// gh-5368
@Test
public void authenticateWhenAuthorizationSuccessResponseThenAdditionalParametersIncluded() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters)
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
success().build());
OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider
.authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange));
assertThat(authentication.getAdditionalParameters())
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}
} }