Ensure ID Token is updated after refresh token (Reactive)

Closes gh-17188

Signed-off-by: Evgeniy Cheban <mister.cheban@gmail.com>
This commit is contained in:
Evgeniy Cheban 2025-06-14 04:41:40 +03:00 committed by Joe Grandja
parent f52f097a4d
commit e4dcffae8a
10 changed files with 1007 additions and 13 deletions

View File

@ -254,6 +254,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
private Duration clockSkew;
private Clock clock;
@ -274,6 +276,21 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
return this;
}
/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} to use for handling
* successful refresh token response, defaults to
* {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}.
* @param authorizationSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @return the {@link RefreshTokenGrantBuilder}
* @since 7.1
*/
public RefreshTokenGrantBuilder authorizationSuccessHandler(
ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
this.authorizationSuccessHandler = authorizationSuccessHandler;
return this;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
@ -310,6 +327,9 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.authorizationSuccessHandler != null) {
authorizedClientProvider.setAuthorizationSuccessHandler(this.authorizationSuccessHandler);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}

View File

@ -0,0 +1,304 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
/**
* A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser}
* in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according
* to <a href=
* "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse">OpenID
* Connect Core 1.0 - Section 12.2 Successful Refresh Response</a>
*
* @author Evgeniy Cheban
* @since 7.1
*/
public final class RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler
implements ReactiveOAuth2AuthorizationSuccessHandler {
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce";
private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse";
// @formatter:off
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.deferContextual(Mono::just)
.filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class));
// @formatter:on
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();
private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = new OidcReactiveOAuth2UserService();
private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities;
private Duration clockSkew = Duration.ofSeconds(60);
@Override
public Mono<Void> onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal,
Map<String, Object> attributes) {
// The response must contain the openid scope.
if (!authorizedClient.getAccessToken().getScopes().contains(OidcScopes.OPENID)) {
return Mono.empty();
}
// The response must contain an id_token.
String idToken = extractIdToken(attributes);
if (!StringUtils.hasText(idToken)) {
return Mono.empty();
}
if (!(principal instanceof OAuth2AuthenticationToken authenticationToken)
|| authenticationToken.getClass() != OAuth2AuthenticationToken.class) {
// If the application customizes the authentication result, then a custom
// handler should be provided.
return Mono.empty();
}
// The current principal must be an OidcUser.
if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) {
return Mono.empty();
}
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
// The registrationId must match the one used to log in.
if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) {
return Mono.empty();
}
// Create, validate OidcIdToken and refresh OidcUser in the SecurityContext.
return Mono.justOrEmpty((ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))
.switchIfEmpty(currentServerWebExchangeMono)
.flatMap((exchange) -> {
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
return jwtDecoder.decode(idToken).onErrorMap(JwtException.class, (ex) -> {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(),
null);
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
})
.map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
jwt.getClaims()))
.doOnNext((oidcIdToken) -> validateIdToken(existingOidcUser, oidcIdToken))
.flatMap((oidcIdToken) -> {
OidcUserRequest userRequest = new OidcUserRequest(clientRegistration,
authorizedClient.getAccessToken(), oidcIdToken);
return this.userService.loadUser(userRequest);
})
.flatMap((oidcUser) -> refreshSecurityContext(exchange, clientRegistration, authenticationToken,
oidcUser));
});
}
/**
* Sets a {@link ServerSecurityContextRepository} to use for refreshing a
* {@link SecurityContext}, defaults to
* {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
/**
* Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc
* id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}.
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use
*/
public void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}
/**
* Sets a {@link GrantedAuthoritiesMapper} to use for mapping
* {@link GrantedAuthority}s, defaults to no-op implementation.
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use
*/
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}
/**
* Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser}
* from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}.
* @param userService the {@link ReactiveOAuth2UserService} to use
*/
public void setUserService(ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
Assert.notNull(userService, "userService cannot be null");
this.userService = userService;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OidcIdToken#getIssuedAt()} to match the existing
* {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds.
* @param clockSkew the maximum acceptable clock skew to use
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}
private String extractIdToken(Map<String, Object> attributes) {
if (attributes.get(OidcParameterNames.ID_TOKEN) instanceof String idToken) {
return idToken;
}
return null;
}
private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) {
// OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response
// If an ID Token is returned as a result of a token refresh request, the
// following requirements apply:
// its iss Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateIssuer(existingOidcUser, idToken);
// its sub Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateSubject(existingOidcUser, idToken);
// its iat Claim MUST represent the time that the new ID Token is issued,
validateIssuedAt(existingOidcUser, idToken);
// its aud Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateAudience(existingOidcUser, idToken);
// if the ID Token contains an auth_time Claim, its value MUST represent the time
// of the original authentication - not the time that the new ID token is issued,
validateAuthenticatedAt(existingOidcUser, idToken);
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of
// the original authentication contained nonce; however, if it is present, its
// value MUST be the same as in the ID Token issued at the time of the original
// authentication,
validateNonce(existingOidcUser, idToken);
}
private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!isValidAudience(existingOidcUser, idToken)) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
List<String> idTokenAudiences = idToken.getAudience();
Set<String> oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience());
if (idTokenAudiences.size() != oidcUserAudiences.size()) {
return false;
}
for (String audience : idTokenAudiences) {
if (!oidcUserAudiences.contains(audience)) {
return false;
}
}
return true;
}
private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (idToken.getAuthenticatedAt() == null) {
return;
}
if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!StringUtils.hasText(idToken.getNonce())) {
return;
}
if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private Mono<Void> refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration,
OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) {
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
.mapAuthorities(oidcUser.getAuthorities());
OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities,
clientRegistration.getRegistrationId());
authenticationResult.setDetails(authenticationToken.getDetails());
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult);
return this.serverSecurityContextRepository.save(exchange, securityContext);
}
}

View File

@ -21,7 +21,9 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import reactor.core.publisher.Mono;
@ -40,6 +42,7 @@ import org.springframework.util.Assert;
* {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
*
* @author Joe Grandja
* @author Evgeniy Cheban
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveRefreshTokenTokenResponseClient
@ -49,6 +52,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient();
private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
private Duration clockSkew = Duration.ofSeconds(60);
private Clock clock = Clock.systemUTC();
@ -96,12 +101,17 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
}
private boolean hasTokenExpired(OAuth2Token token) {
return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew));
.flatMap((tokenResponse) -> {
OAuth2AuthorizedClient refreshedClient = new OAuth2AuthorizedClient(clientRegistration,
context.getPrincipal().getName(), tokenResponse.getAccessToken(),
tokenResponse.getRefreshToken());
Map<String, Object> attributes = new HashMap<>(context.getAttributes());
attributes.putAll(tokenResponse.getAdditionalParameters());
return this.authorizationSuccessHandler
.onAuthorizationSuccess(refreshedClient, context.getPrincipal(),
Collections.unmodifiableMap(attributes))
.thenReturn(refreshedClient);
});
}
/**
@ -116,6 +126,19 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} to use for handling
* successful refresh token response, defaults to
* {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}.
* @param authorizationSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @since 7.1
*/
public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
this.authorizationSuccessHandler = authorizationSuccessHandler;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
@ -143,4 +166,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
this.clock = clock;
}
private boolean hasTokenExpired(OAuth2Token token) {
return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew));
}
}

View File

@ -318,14 +318,14 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(currentServerWebExchangeMono)
.flatMap((exchange) -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
Map<String, Object> contextAttributes = new HashMap<>();
contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange);
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
return Mono.just(Collections.unmodifiableMap(contextAttributes));
})
.defaultIfEmpty(Collections.emptyMap());
// @formatter:on

View File

@ -51,6 +51,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -96,6 +98,7 @@ import org.springframework.web.server.ServerWebExchange;
* @author Rob Winch
* @author Joe Grandja
* @author Phil Clay
* @author Evgeniy Cheban
* @since 5.1
*/
public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@ -139,6 +142,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private ClientResponseHandler clientResponseHandler;
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
@ -330,8 +335,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
}
private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
return next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono));
// Re-request an Authentication from serverSecurityContextRepository since it
// might have been changed during provider invocation.
return effectiveAuthentication(request).flatMap((authentication) -> next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono))
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
@ -362,6 +370,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
// @formatter:on
}
private Mono<Authentication> effectiveAuthentication(ClientRequest request) {
// @formatter:off
return effectiveServerWebExchange(request)
.filter(Optional::isPresent)
.map(Optional::get)
.flatMap(this.serverSecurityContextRepository::load)
.mapNotNull(SecurityContext::getAuthentication)
.switchIfEmpty(this.currentAuthenticationMono);
// @formatter:on
}
/**
* Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active
* for the given request.
@ -445,6 +464,19 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
/**
* Sets a {@link ServerSecurityContextRepository} to use for re-obtaining a
* {@link SecurityContext} if it has been refreshed during provider invocation,
* defaults to {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
* @since 7.1
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
@FunctionalInterface
private interface ClientResponseHandler {

View File

@ -0,0 +1,398 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThatException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}.
*
* @author Evgeniy Cheban
*/
class RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandlerTests {
@Test
void onAuthorizationSuccessWhenIdTokenValidThenSecurityContextRefreshed() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyComplete();
StepVerifier.create(serverSecurityContextRepository.load(exchange).map(SecurityContext::getAuthentication))
.expectNext(authenticationToken)
.verifyComplete();
}
@Test
void onAuthorizationSuccessWhenIdTokenIssuerNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", "https://issuer.com");
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid issuer");
}
@Test
void onAuthorizationSuccessWhenIdTokenSubNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", "invalid_sub");
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid subject");
}
@Test
void onAuthorizationSuccessWhenIdTokenIatNotAfterThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt().minus(Duration.ofDays(1)));
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid issued at time");
}
@Test
void onAuthorizationSuccessWhenIdTokenAudEmptyThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", Collections.emptyList());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid audience");
}
@Test
void onAuthorizationSuccessWhenIdTokenAudNotContainThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", List.of("invalid_client-id"));
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid audience");
}
@Test
void onAuthorizationSuccessWhenIdTokenAuthTimeNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("auth_time", principal.getIssuedAt());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid authenticated at time");
}
@Test
void onAuthorizationSuccessWhenIdTokenNonceNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken, null);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(ServerWebExchange.class.getName(), exchange,
OidcParameterNames.ID_TOKEN, "id-token-1234");
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", "invalid_nonce");
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_nonce] Invalid nonce");
}
@Test
void setServerSecurityContextRepositoryWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler()
.setServerSecurityContextRepository(null))
.withMessage("serverSecurityContextRepository cannot be null");
}
@Test
void setJwtDecoderFactoryWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setJwtDecoderFactory(null))
.withMessage("jwtDecoderFactory cannot be null");
}
@Test
void setAuthoritiesMapperWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setAuthoritiesMapper(null))
.withMessage("authoritiesMapper cannot be null");
}
@Test
void setUserServiceWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setUserService(null))
.withMessage("userService cannot be null");
}
@Test
void setClockSkewWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setClockSkew(null))
.withMessage("clockSkew cannot be null");
}
private static OAuth2AccessToken createAccessToken() {
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60));
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt,
Set.of(OidcScopes.OPENID));
}
}

View File

@ -49,6 +49,7 @@ import static org.mockito.Mockito.verify;
* Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
* @author Evgeniy Cheban
*/
public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
@ -84,6 +85,15 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
.withMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setAuthorizationSuccessHandlerWhenHandlerIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setAuthorizationSuccessHandler(null))
.withMessage("authorizationSuccessHandler cannot be null");
// @formatter:on
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off

View File

@ -18,9 +18,13 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import reactor.core.publisher.Mono;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
@ -29,14 +33,21 @@ import static org.mockito.Mockito.mock;
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
public class MockExchangeFunction implements ExchangeFunction {
private final AtomicReference<Authentication> authenticationCaptor = new AtomicReference<>();
private List<ClientRequest> requests = new ArrayList<>();
private ClientResponse response = mock(ClientResponse.class);
public Authentication getCapturedAuthentication() {
return this.authenticationCaptor.get();
}
public ClientRequest getRequest() {
return this.requests.get(this.requests.size() - 1);
}
@ -53,8 +64,14 @@ public class MockExchangeFunction implements ExchangeFunction {
public Mono<ClientResponse> exchange(ClientRequest request) {
return Mono.defer(() -> {
this.requests.add(request);
return Mono.just(this.response);
return captureAuthentication().then(Mono.just(this.response));
});
}
private Mono<Authentication> captureAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnNext(this.authenticationCaptor::set);
}
}

View File

@ -19,8 +19,15 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.jupiter.api.AfterEach;
@ -38,23 +45,42 @@ import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JwsHeader;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.InstanceOfAssertFactories.type;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
@ -67,9 +93,12 @@ import static org.mockito.Mockito.verify;
/**
* @author Phil Clay
* @author Evgeniy Cheban
*/
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
private final AuthenticationCapturingExchangeFilterFunction authenticationCapturingFilter = new AuthenticationCapturingExchangeFilterFunction();
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
@ -118,6 +147,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
// @formatter:off
this.webClient = WebClient.builder()
.filter(this.authorizedClientFilter)
.filter(this.authenticationCapturingFilter)
.build();
// @formatter:on
this.authentication = new TestingAuthenticationToken("principal", "password");
@ -218,6 +248,116 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token");
}
@Test
public void requestWhenAuthorizedButExpiredThenRefreshSecurityContext() throws Exception {
// Current OIDC user in the SecurityContext.
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.issuerUri(this.serverUrl)
.tokenUri(this.serverUrl)
.jwkSetUri(this.serverUrl + "oauth2/jwk")
.userInfoUri(this.serverUrl + "user")
.userNameAttributeName("username")
.build();
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofHours(1));
OidcIdToken existingIdToken = OidcIdToken.withTokenValue("id-token-1234")
.issuer(this.serverUrl)
.subject("subject-1234")
.audience(List.of(clientRegistration.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.build();
OidcUser existingOidcUser = new DefaultOidcUser(Collections.emptyList(), existingIdToken);
this.authentication = new OAuth2AuthenticationToken(existingOidcUser, Collections.emptyList(),
clientRegistration.getRegistrationId());
// Generate new OIDC ID token with refreshed user information.
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).type("JWT").build();
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder()
.subject("subject-1234")
.audience(List.of(clientRegistration.getClientId()))
.issuer(this.serverUrl)
.issuedAt(Instant.now())
.expiresAt(Instant.now().plusSeconds(3600))
.build();
RSAKey key = new RSAKeyGenerator(2048).generate();
JWKSet jwkSet = new JWKSet(key);
JwtEncoder jwtEncoder = new NimbusJwtEncoder(new ImmutableJWKSet<>(jwkSet));
String idToken = jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)).getTokenValue();
// @formatter:off
String accessTokenResponse = """
{
"access_token": "refreshed-access-token",
"id_token": "%s",
"token_type": "bearer",
"expires_in": "3600"
}
""".formatted(idToken);
String userInfoResponse = """
{
"sub": "subject-1234",
"username": "refreshed-username"
}
""";
String clientResponse = """
{
"attribute1": "value1",
"attribute2": "value2"
}
""";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(jwkSet.toString()));
this.server.enqueue(jsonResponse(userInfoResponse));
this.server.enqueue(jsonResponse(clientResponse));
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId())))
.willReturn(Mono.just(clientRegistration));
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"expired-access-token", issuedAt, expiresAt,
new HashSet<>(Arrays.asList("read", "write", OidcScopes.OPENID)));
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration,
this.authentication.getName(), accessToken, refreshToken);
doReturn(Mono.just(authorizedClient)).when(this.authorizedClientRepository)
.loadAuthorizedClient(eq(clientRegistration.getRegistrationId()), eq(this.authentication),
eq(this.exchange));
this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
.contextWrite(Context.of(ServerWebExchange.class, this.exchange))
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
.block();
assertThat(this.server.getRequestCount()).isEqualTo(4);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizedClient.class);
verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(),
eq(this.authentication), eq(this.exchange));
OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue();
assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration);
assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token");
WebSessionServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
// Capture and verify that the refreshed Authentication object was propagated to
// the next ExchangeFilterFunction's context.
Authentication capturedAuthentication = this.authenticationCapturingFilter.getCapturedAuthentication();
assertThat(capturedAuthentication).isNotNull();
Authentication refreshedAuthentication = securityContextRepository.load(this.exchange)
.mapNotNull(SecurityContext::getAuthentication)
.block();
assertThat(refreshedAuthentication).isNotNull();
assertThat(refreshedAuthentication).isSameAs(capturedAuthentication);
assertThat(refreshedAuthentication).asInstanceOf(type(OAuth2AuthenticationToken.class))
.extracting(OAuth2AuthenticationToken::getPrincipal)
.asInstanceOf(type(OidcUser.class))
.satisfies((oidcUser) -> {
// Verify that the OidcUser's attributes match the id_token's claims.
assertThat(oidcUser.getIdToken()).extracting(AbstractOAuth2Token::getTokenValue).isEqualTo(idToken);
assertThat(oidcUser.getSubject()).isEqualTo("subject-1234");
assertThat(oidcUser.getName()).isEqualTo("refreshed-username");
});
}
@Test
public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
// @formatter:off
@ -349,4 +489,22 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
// @formatter:on
}
private static final class AuthenticationCapturingExchangeFilterFunction implements ExchangeFilterFunction {
private final AtomicReference<Authentication> authenticationCaptor = new AtomicReference<>();
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return ReactiveSecurityContextHolder.getContext().flatMap((ctx) -> {
this.authenticationCaptor.set(ctx.getAuthentication());
return next.exchange(request);
});
}
private Authentication getCapturedAuthentication() {
return this.authenticationCaptor.get();
}
}
}

View File

@ -59,6 +59,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
@ -66,6 +67,7 @@ import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2Authoriz
import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
@ -89,8 +91,10 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.TestOAuth2Users;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
@ -113,6 +117,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
@ -133,6 +138,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> jwtBearerTokenResponseClient;
@Mock
private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
@Mock
private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
@ -170,7 +178,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.builder()
.authorizationCode()
.refreshToken(
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
(configurer) -> configurer
.authorizationSuccessHandler(this.authorizationSuccessHandler)
.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
.clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.provider(jwtBearerAuthorizedClientProvider)
@ -201,6 +211,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null));
}
@Test
public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setServerSecurityContextRepository(null));
}
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@ -312,6 +329,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.refreshToken("refresh-1")
.build();
given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response));
given(this.authorizationSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(),
@ -326,14 +344,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
// @formatter:on
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
// @formatter:off
DefaultOAuth2User refreshedUser = TestOAuth2Users.create();
OAuth2AuthenticationToken refreshedAuthentication = new OAuth2AuthenticationToken(refreshedUser, refreshedUser.getAuthorities(), this.registration.getRegistrationId());
SecurityContextImpl securityContext = new SecurityContextImpl(refreshedAuthentication);
ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class);
given(securityContextRepository.load(this.serverWebExchange)).willReturn(Mono.just(securityContext));
this.function.setServerSecurityContextRepository(securityContextRepository);
this.function.filter(request, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))
.contextWrite(serverWebExchange())
.block();
Authentication currentAuthentication = this.exchange.getCapturedAuthentication();
assertThat(currentAuthentication).isSameAs(refreshedAuthentication);
// @formatter:on
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(),
eq(authentication), any());
verify(securityContextRepository).load(this.serverWebExchange);
OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue();
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
@ -355,6 +382,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.refreshToken("refresh-1")
.build();
given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response));
given(this.authorizationSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(),