mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-05-31 01:02:14 +00:00
Ensure ID Token is updated after refresh token
Signed-off-by: Hao <kyrieeeee2@gmail.com>
This commit is contained in:
parent
ece7489f5b
commit
fc1469ad5e
@ -34,6 +34,9 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.ApplicationContextAware;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
@ -160,7 +163,7 @@ final class OAuth2ClientConfiguration {
|
||||
* @since 6.2.0
|
||||
*/
|
||||
static final class OAuth2AuthorizedClientManagerRegistrar
|
||||
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
|
||||
implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
|
||||
|
||||
static final String BEAN_NAME = "authorizedClientManagerRegistrar";
|
||||
|
||||
@ -179,6 +182,8 @@ final class OAuth2ClientConfiguration {
|
||||
|
||||
private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
|
||||
|
||||
private ApplicationEventPublisher eventPublisher;
|
||||
|
||||
private ListableBeanFactory beanFactory;
|
||||
|
||||
@Override
|
||||
@ -302,6 +307,10 @@ final class OAuth2ClientConfiguration {
|
||||
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
|
||||
}
|
||||
|
||||
if (this.eventPublisher != null) {
|
||||
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
|
||||
}
|
||||
|
||||
return authorizedClientProvider;
|
||||
}
|
||||
|
||||
@ -423,6 +432,11 @@ final class OAuth2ClientConfiguration {
|
||||
return objectProvider.getIfAvailable();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||
this.eventPublisher = applicationContext;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationC
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
|
||||
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
|
||||
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
|
||||
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
|
||||
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
|
||||
@ -394,6 +395,15 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>>
|
||||
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
|
||||
}
|
||||
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));
|
||||
|
||||
RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler();
|
||||
if (this.getSecurityContextHolderStrategy() != null) {
|
||||
refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy());
|
||||
}
|
||||
if (jwtDecoderFactory != null) {
|
||||
refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory);
|
||||
}
|
||||
registerDelegateApplicationListener(refreshOidcIdTokenHandler);
|
||||
}
|
||||
else {
|
||||
http.authenticationProvider(new OidcAuthenticationRequestChecker());
|
||||
|
@ -34,6 +34,7 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
|
||||
import org.springframework.core.ResolvableType;
|
||||
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
|
||||
@ -197,6 +198,12 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
|
||||
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
|
||||
}
|
||||
|
||||
ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
|
||||
ResolvableType.forClass(ApplicationEventPublisher.class));
|
||||
if (applicationEventPublisher != null) {
|
||||
authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
|
||||
}
|
||||
|
||||
return authorizedClientProvider;
|
||||
}
|
||||
|
||||
|
@ -25,6 +25,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
|
||||
@ -359,6 +360,8 @@ public final class OAuth2AuthorizedClientProviderBuilder {
|
||||
|
||||
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
|
||||
|
||||
private ApplicationEventPublisher eventPublisher;
|
||||
|
||||
private Duration clockSkew;
|
||||
|
||||
private Clock clock;
|
||||
@ -379,6 +382,17 @@ public final class OAuth2AuthorizedClientProviderBuilder {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link ApplicationEventPublisher} used when an access token is
|
||||
* refreshed.
|
||||
* @param eventPublisher the {@link ApplicationEventPublisher}
|
||||
* @return the {@link RefreshTokenGrantBuilder}
|
||||
*/
|
||||
public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
|
||||
this.eventPublisher = eventPublisher;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the maximum acceptable clock skew, which is used when checking the access
|
||||
* token expiry. An access token is considered expired if
|
||||
@ -414,6 +428,9 @@ public final class OAuth2AuthorizedClientProviderBuilder {
|
||||
if (this.accessTokenResponseClient != null) {
|
||||
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
|
||||
}
|
||||
if (this.eventPublisher != null) {
|
||||
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
|
||||
}
|
||||
if (this.clockSkew != null) {
|
||||
authorizedClientProvider.setClockSkew(this.clockSkew);
|
||||
}
|
||||
|
@ -24,10 +24,13 @@ import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.context.ApplicationEventPublisherAware;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
|
||||
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Token;
|
||||
@ -43,10 +46,13 @@ import org.springframework.util.Assert;
|
||||
* @see OAuth2AuthorizedClientProvider
|
||||
* @see DefaultRefreshTokenTokenResponseClient
|
||||
*/
|
||||
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
|
||||
public final class RefreshTokenOAuth2AuthorizedClientProvider
|
||||
implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {
|
||||
|
||||
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
|
||||
|
||||
private ApplicationEventPublisher eventPublisher;
|
||||
|
||||
private Duration clockSkew = Duration.ofSeconds(60);
|
||||
|
||||
private Clock clock = Clock.systemUTC();
|
||||
@ -91,8 +97,17 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A
|
||||
authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
|
||||
authorizedClient.getRefreshToken(), scopes);
|
||||
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
|
||||
return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
|
||||
context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
|
||||
|
||||
OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
|
||||
authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
|
||||
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
|
||||
|
||||
if (this.eventPublisher != null) {
|
||||
this.eventPublisher
|
||||
.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
|
||||
}
|
||||
|
||||
return updatedOAuth2AuthorizedClient;
|
||||
}
|
||||
|
||||
private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
|
||||
@ -149,4 +164,9 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A
|
||||
this.clock = clock;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
|
||||
this.eventPublisher = applicationEventPublisher;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2002-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.event;
|
||||
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
|
||||
/**
|
||||
* An event that is published when an OAuth2 access token is refreshed.
|
||||
*/
|
||||
public class OAuth2TokenRefreshedEvent extends ApplicationEvent {
|
||||
|
||||
private final OAuth2AuthorizedClient authorizedClient;
|
||||
|
||||
private final OAuth2AccessTokenResponse accessTokenResponse;
|
||||
|
||||
public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
|
||||
OAuth2AccessTokenResponse accessTokenResponse) {
|
||||
super(source);
|
||||
this.authorizedClient = authorizedClient;
|
||||
this.accessTokenResponse = accessTokenResponse;
|
||||
}
|
||||
|
||||
public OAuth2AuthorizedClient getAuthorizedClient() {
|
||||
return this.authorizedClient;
|
||||
}
|
||||
|
||||
public OAuth2AccessTokenResponse getAccessTokenResponse() {
|
||||
return this.accessTokenResponse;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,139 @@
|
||||
/*
|
||||
* Copyright 2002-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.oidc.authentication;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
||||
import org.springframework.security.oauth2.core.oidc.OidcScopes;
|
||||
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
|
||||
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
|
||||
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
|
||||
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
||||
import org.springframework.security.oauth2.jwt.Jwt;
|
||||
import org.springframework.security.oauth2.jwt.JwtDecoder;
|
||||
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
|
||||
import org.springframework.security.oauth2.jwt.JwtException;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
|
||||
*/
|
||||
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {
|
||||
|
||||
private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token";
|
||||
|
||||
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
|
||||
|
||||
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
|
||||
.getContextHolderStrategy();
|
||||
|
||||
private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new OidcIdTokenDecoderFactory();
|
||||
|
||||
@Override
|
||||
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
|
||||
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
|
||||
|
||||
if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
|
||||
if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) {
|
||||
return;
|
||||
}
|
||||
if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
|
||||
|
||||
String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
|
||||
if (idToken == null || idToken.isBlank()) {
|
||||
OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE,
|
||||
"ID token is missing in the token response", null);
|
||||
throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString());
|
||||
}
|
||||
|
||||
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
|
||||
OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);
|
||||
updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
|
||||
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
|
||||
*/
|
||||
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
|
||||
this.securityContextHolderStrategy = securityContextHolderStrategy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature
|
||||
* verification. The factory returns a {@link JwtDecoder} associated to the provided
|
||||
* {@link ClientRegistration}.
|
||||
* @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken}
|
||||
* signature verification
|
||||
*/
|
||||
public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
|
||||
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
|
||||
this.jwtDecoderFactory = jwtDecoderFactory;
|
||||
}
|
||||
|
||||
private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser,
|
||||
OidcIdToken refreshedOidcToken) {
|
||||
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
|
||||
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
|
||||
|
||||
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
|
||||
context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
|
||||
oauth2Authentication.getAuthorizedClientRegistrationId()));
|
||||
|
||||
this.securityContextHolderStrategy.setContext(context);
|
||||
}
|
||||
|
||||
private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
|
||||
OAuth2AccessTokenResponse accessTokenResponse) {
|
||||
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
|
||||
Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
|
||||
return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
|
||||
}
|
||||
|
||||
private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) {
|
||||
try {
|
||||
Map<String, Object> parameters = accessTokenResponse.getAdditionalParameters();
|
||||
return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN));
|
||||
}
|
||||
catch (JwtException ex) {
|
||||
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
|
||||
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -25,10 +25,12 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
|
||||
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
@ -251,4 +253,55 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
|
||||
+ OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldPublishEventWhenTokenRefreshed() {
|
||||
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
|
||||
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
|
||||
// @formatter:off
|
||||
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
|
||||
.accessTokenResponse()
|
||||
.refreshToken("new-refresh-token")
|
||||
.build();
|
||||
// @formatter:on
|
||||
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
|
||||
// @formatter:off
|
||||
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
|
||||
.withAuthorizedClient(this.authorizedClient)
|
||||
.principal(this.principal)
|
||||
.build();
|
||||
// @formatter:on
|
||||
this.authorizedClientProvider.authorize(authorizationContext);
|
||||
assertThat(eventPublisher.flag).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNotPublishEventWhenTokenNotRefreshed() {
|
||||
OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
|
||||
this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
|
||||
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
|
||||
// @formatter:off
|
||||
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
|
||||
.withAuthorizedClient(authorizedClient)
|
||||
.principal(this.principal)
|
||||
.build();
|
||||
// @formatter:on
|
||||
this.authorizedClientProvider.authorize(authorizationContext);
|
||||
assertThat(eventPublisher.flag).isFalse();
|
||||
}
|
||||
|
||||
private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher {
|
||||
|
||||
Boolean flag = false;
|
||||
|
||||
@Override
|
||||
public void publishEvent(Object event) {
|
||||
if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) {
|
||||
this.flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,284 @@
|
||||
/*
|
||||
* Copyright 2002-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.oidc.authentication;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.authority.AuthorityUtils;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
||||
import org.springframework.security.oauth2.core.oidc.OidcScopes;
|
||||
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
|
||||
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
|
||||
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
|
||||
import org.springframework.security.oauth2.jwt.Jwt;
|
||||
import org.springframework.security.oauth2.jwt.JwtDecoder;
|
||||
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
|
||||
import org.springframework.security.oauth2.jwt.JwtException;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class RefreshOidcIdTokenHandlerTests {
|
||||
|
||||
private static final String EXISTING_ID_TOKEN_VALUE = "id-token-value";
|
||||
|
||||
private static final String REFRESHED_ID_TOKEN_VALUE = "new-id-token-value";
|
||||
|
||||
private static final String EXISTING_ACCESS_TOKEN_VALUE = "token-value";
|
||||
|
||||
private static final String REFRESHED_ACCESS_TOKEN_VALUE = "new-token-value";
|
||||
|
||||
private RefreshOidcIdTokenHandler handler;
|
||||
|
||||
private RefreshTokenOAuth2AuthorizedClientProvider provider;
|
||||
|
||||
private ClientRegistration clientRegistration;
|
||||
|
||||
private OAuth2AuthorizedClient authorizedClient;
|
||||
|
||||
private JwtDecoder jwtDecoder;
|
||||
|
||||
private SecurityContext securityContext;
|
||||
|
||||
private OidcIdToken existingIdToken;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
this.handler = new RefreshOidcIdTokenHandler();
|
||||
|
||||
this.clientRegistration = createClientRegistrationWithScopes(OidcScopes.OPENID);
|
||||
this.authorizedClient = createAuthorizedClient(this.clientRegistration);
|
||||
|
||||
this.provider = mock(RefreshTokenOAuth2AuthorizedClientProvider.class);
|
||||
|
||||
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = mock(JwtDecoderFactory.class);
|
||||
this.jwtDecoder = mock(JwtDecoder.class);
|
||||
SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class);
|
||||
this.securityContext = mock(SecurityContext.class);
|
||||
|
||||
this.handler.setJwtDecoderFactory(jwtDecoderFactory);
|
||||
this.handler.setSecurityContextHolderStrategy(securityContextHolderStrategy);
|
||||
|
||||
given(jwtDecoderFactory.createDecoder(any())).willReturn(this.jwtDecoder);
|
||||
given(securityContextHolderStrategy.createEmptyContext()).willReturn(this.securityContext);
|
||||
given(securityContextHolderStrategy.getContext()).willReturn(this.securityContext);
|
||||
|
||||
Map<String, Object> claims = new HashMap<>();
|
||||
claims.put("sub", "subject");
|
||||
Jwt existingIdTokenJwt = new Jwt(EXISTING_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
|
||||
Map.of("alg", "RS256"), claims);
|
||||
Jwt refreshedIdTokenJwt = new Jwt(REFRESHED_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
|
||||
Map.of("alg", "RS256"), claims);
|
||||
|
||||
this.existingIdToken = new OidcIdToken(existingIdTokenJwt.getTokenValue(), existingIdTokenJwt.getIssuedAt(),
|
||||
existingIdTokenJwt.getExpiresAt(), existingIdTokenJwt.getClaims());
|
||||
|
||||
given(this.jwtDecoder.decode(existingIdTokenJwt.getTokenValue())).willReturn(existingIdTokenJwt);
|
||||
given(this.jwtDecoder.decode(refreshedIdTokenJwt.getTokenValue())).willReturn(refreshedIdTokenJwt);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenValidIdTokenThenUpdatesSecurityContext() {
|
||||
|
||||
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
|
||||
this.existingIdToken);
|
||||
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
|
||||
existingUser.getAuthorities(), "registration-id");
|
||||
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
this.handler.onApplicationEvent(event);
|
||||
|
||||
ArgumentCaptor<OAuth2AuthenticationToken> authenticationCaptor = ArgumentCaptor
|
||||
.forClass(OAuth2AuthenticationToken.class);
|
||||
verify(this.securityContext).setAuthentication(authenticationCaptor.capture());
|
||||
|
||||
OAuth2AuthenticationToken newAuthentication = authenticationCaptor.getValue();
|
||||
assertThat(newAuthentication.getPrincipal()).isInstanceOf(DefaultOidcUser.class);
|
||||
DefaultOidcUser newUser = (DefaultOidcUser) newAuthentication.getPrincipal();
|
||||
assertThat(newUser.getIdToken().getTokenValue()).isEqualTo(REFRESHED_ID_TOKEN_VALUE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenAuthorizedClientIsNotOidcThenDoesNothing() {
|
||||
|
||||
this.clientRegistration = createClientRegistrationWithScopes("read");
|
||||
this.authorizedClient = createAuthorizedClient(this.clientRegistration);
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
|
||||
this.handler.onApplicationEvent(event);
|
||||
|
||||
verify(this.securityContext, never()).setAuthentication(any());
|
||||
verify(this.jwtDecoder, never()).decode(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenAuthenticationNotOAuth2AuthenticationTokenThenDoesNothing() {
|
||||
|
||||
given(this.securityContext.getAuthentication()).willReturn(mock(TestingAuthenticationToken.class));
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
|
||||
this.handler.onApplicationEvent(event);
|
||||
|
||||
verify(this.securityContext, never()).setAuthentication(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenNotOidcUserThenDoesNothing() {
|
||||
|
||||
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(
|
||||
new DefaultOAuth2User(Collections.emptySet(),
|
||||
Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"),
|
||||
AuthorityUtils.createAuthorityList("ROLE_USER"), "registration-id");
|
||||
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
|
||||
this.handler.onApplicationEvent(event);
|
||||
|
||||
verify(this.securityContext, never()).setAuthentication(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenMissingIdTokenThenThrowsException() {
|
||||
|
||||
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
|
||||
this.existingIdToken);
|
||||
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
|
||||
existingUser.getAuthorities(), "registration-id");
|
||||
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(new HashMap<>()) // missing ID token
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
|
||||
assertThatExceptionOfType(OAuth2AuthenticationException.class)
|
||||
.isThrownBy(() -> this.handler.onApplicationEvent(event))
|
||||
.withMessageContaining("missing_id_token");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleEventWhenInvalidIdTokenThenThrowsException() {
|
||||
|
||||
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
|
||||
this.existingIdToken);
|
||||
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
|
||||
existingUser.getAuthorities(), "registration-id");
|
||||
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
|
||||
|
||||
given(this.jwtDecoder.decode(any())).willThrow(new JwtException("Invalid token"));
|
||||
|
||||
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
|
||||
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "invalid-id-token"))
|
||||
.build();
|
||||
|
||||
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
|
||||
accessTokenResponse);
|
||||
|
||||
assertThatExceptionOfType(OAuth2AuthenticationException.class)
|
||||
.isThrownBy(() -> this.handler.onApplicationEvent(event))
|
||||
.withMessageContaining("invalid_id_token");
|
||||
}
|
||||
|
||||
private ClientRegistration createClientRegistrationWithScopes(String... scope) {
|
||||
return ClientRegistration.withRegistrationId("registration-id")
|
||||
.clientId("client-id")
|
||||
.clientSecret("secret")
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.redirectUri("http://localhost")
|
||||
.scope(scope)
|
||||
.authorizationUri("https://provider.com/oauth2/authorize")
|
||||
.tokenUri("https://provider.com/oauth2/token")
|
||||
.jwkSetUri("https://provider.com/jwk")
|
||||
.userInfoUri("https://provider.com/user")
|
||||
.build();
|
||||
}
|
||||
|
||||
private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) {
|
||||
return new OAuth2AuthorizedClient(clientRegistration, "principal-name",
|
||||
new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, EXISTING_ACCESS_TOKEN_VALUE, Instant.now(),
|
||||
Instant.now().plusSeconds(3600)));
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user