Configure token-exchange via a bean

Issue gh-5199
Issue gh-11783
Closes gh-14701
This commit is contained in:
Steve Riesenberg 2024-03-07 11:03:10 -06:00
parent 85c3d0ab13
commit d6382b83dc
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
6 changed files with 208 additions and 2 deletions

View File

@ -51,11 +51,13 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
@ -183,7 +185,8 @@ final class OAuth2ClientConfiguration {
RefreshTokenOAuth2AuthorizedClientProvider.class,
ClientCredentialsOAuth2AuthorizedClientProvider.class,
PasswordOAuth2AuthorizedClientProvider.class,
JwtBearerOAuth2AuthorizedClientProvider.class
JwtBearerOAuth2AuthorizedClientProvider.class,
TokenExchangeOAuth2AuthorizedClientProvider.class
);
// @formatter:on
@ -255,6 +258,12 @@ final class OAuth2ClientConfiguration {
authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
}
OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
authorizedClientProviderBeans);
if (tokenExchangeAuthorizedClientProvider != null) {
authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
}
authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
}
@ -364,6 +373,25 @@ final class OAuth2ClientConfiguration {
return authorizedClientProvider;
}
private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);
OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
TokenExchangeGrantRequest.class));
if (accessTokenResponseClient != null) {
if (authorizedClientProvider == null) {
authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
}
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}
return authorizedClientProvider;
}
private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(

View File

@ -44,11 +44,13 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
@ -76,7 +78,8 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
RefreshTokenOAuth2AuthorizedClientProvider.class,
ClientCredentialsOAuth2AuthorizedClientProvider.class,
PasswordOAuth2AuthorizedClientProvider.class,
JwtBearerOAuth2AuthorizedClientProvider.class
JwtBearerOAuth2AuthorizedClientProvider.class,
TokenExchangeOAuth2AuthorizedClientProvider.class
);
// @formatter:on
@ -137,6 +140,12 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
}
OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
authorizedClientProviderBeans);
if (tokenExchangeAuthorizedClientProvider != null) {
authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
}
authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
}
@ -245,6 +254,25 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
return authorizedClientProvider;
}
private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);
OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
TokenExchangeGrantRequest.class));
if (accessTokenResponseClient != null) {
if (authorizedClientProvider == null) {
authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
}
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}
return authorizedClientProvider;
}
private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(

View File

@ -52,6 +52,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
@ -59,6 +60,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -70,6 +72,7 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -327,6 +330,47 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
}
@Test
public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
testTokenExchangeGrant();
}
@Test
public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
testTokenExchangeGrant();
}
private void testTokenExchangeGrant() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(clientRegistration.getRegistrationId())
.principal(authentication)
.attribute(HttpServletRequest.class.getName(), this.request)
.attribute(HttpServletResponse.class.getName(), this.response)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
assertThat(authorizedClient).isNotNull();
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
assertThat(grantRequest.getClientRegistration().getRegistrationId())
.isEqualTo(clientRegistration.getRegistrationId());
assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
}
private static OAuth2AccessToken getExpiredAccessToken() {
Instant expiresAt = Instant.now().minusSeconds(60);
Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
@ -376,6 +420,11 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
return new MockJwtBearerClient();
}
@Bean
OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeTokenResponseClient() {
return new MockTokenExchangeClient();
}
@Bean
OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService() {
return mock(DefaultOAuth2UserService.class);
@ -425,6 +474,13 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
return authorizedClientProvider;
}
@Bean
TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
authorizedClientProvider.setAccessTokenResponseClient(new MockTokenExchangeClient());
return authorizedClientProvider;
}
}
abstract static class OAuth2ClientBaseConfig {
@ -463,6 +519,14 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
.clientId("okta-client-id")
.clientSecret("okta-client-secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.build(),
ClientRegistration.withRegistrationId("auth0")
.clientName("Auth0")
.clientId("auth0-client-id")
.clientSecret("auth0-client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.scope("user.read", "user.write")
.build()));
// @formatter:on
}
@ -544,4 +608,13 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
}
private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
@Override
public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
}
}
}

View File

@ -49,6 +49,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
@ -56,11 +57,13 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -316,6 +319,47 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
}
@Test
public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
this.spring.configLocations(xml("clients")).autowire();
testTokenExchangeGrant();
}
@Test
public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
this.spring.configLocations(xml("providers")).autowire();
testTokenExchangeGrant();
}
private void testTokenExchangeGrant() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(clientRegistration.getRegistrationId())
.principal(authentication)
.attribute(HttpServletRequest.class.getName(), this.request)
.attribute(HttpServletResponse.class.getName(), this.response)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
assertThat(authorizedClient).isNotNull();
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
assertThat(grantRequest.getClientRegistration().getRegistrationId())
.isEqualTo(clientRegistration.getRegistrationId());
assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
}
private static OAuth2AccessToken getExpiredAccessToken() {
Instant expiresAt = Instant.now().minusSeconds(60);
Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
@ -356,6 +400,14 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
.clientId("okta-client-id")
.clientSecret("okta-client-secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.build(),
ClientRegistration.withRegistrationId("auth0")
.clientName("Auth0")
.clientId("auth0-client-id")
.clientSecret("auth0-client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.scope("user.read", "user.write")
.build());
// @formatter:on
}
@ -422,6 +474,16 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
return new MockJwtBearerClient();
}
public static TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
authorizedClientProvider.setAccessTokenResponseClient(tokenExchangeAccessTokenResponseClient());
return authorizedClientProvider;
}
public static OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeAccessTokenResponseClient() {
return new MockTokenExchangeClient();
}
private static class MockAuthorizationCodeClient
implements OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
@ -472,4 +534,13 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
}
private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
@Override
public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
}
}
}

View File

@ -53,4 +53,7 @@
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="jwtBearerAccessTokenResponseClient"/>
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="tokenExchangeAccessTokenResponseClient"/>
</b:beans>

View File

@ -56,4 +56,7 @@
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="jwtBearerAuthorizedClientProvider"/>
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="tokenExchangeAuthorizedClientProvider"/>
</b:beans>