diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java index 3d3e76c634..3e6d88be93 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java @@ -254,6 +254,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private Duration clockSkew; private Clock clock; @@ -274,6 +276,21 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { return this; } + /** + * Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} to use for handling + * successful refresh token response, defaults to + * {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}. + * @param authorizationSuccessHandler the + * {@link ReactiveOAuth2AuthorizationSuccessHandler} to use + * @return the {@link RefreshTokenGrantBuilder} + * @since 7.1 + */ + public RefreshTokenGrantBuilder authorizationSuccessHandler( + ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + this.authorizationSuccessHandler = authorizationSuccessHandler; + return this; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the access * token expiry. An access token is considered expired if @@ -310,6 +327,9 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { if (this.accessTokenResponseClient != null) { authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); } + if (this.authorizationSuccessHandler != null) { + authorizedClientProvider.setAuthorizationSuccessHandler(this.authorizationSuccessHandler); + } if (this.clockSkew != null) { authorizedClientProvider.setClockSkew(this.clockSkew); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler.java new file mode 100644 index 0000000000..064aafa1b9 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,304 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Duration; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser} + * in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according + * to OpenID + * Connect Core 1.0 - Section 12.2 Successful Refresh Response + * + * @author Evgeniy Cheban + * @since 7.1 + */ +public final class RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler + implements ReactiveOAuth2AuthorizationSuccessHandler { + + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; + + private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse"; + + // @formatter:off + private static final Mono currentServerWebExchangeMono = Mono.deferContextual(Mono::just) + .filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)); + // @formatter:on + + private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + + private ReactiveJwtDecoderFactory jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(); + + private ReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + + private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities; + + private Duration clockSkew = Duration.ofSeconds(60); + + @Override + public Mono onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, + Map attributes) { + // The response must contain the openid scope. + if (!authorizedClient.getAccessToken().getScopes().contains(OidcScopes.OPENID)) { + return Mono.empty(); + } + // The response must contain an id_token. + String idToken = extractIdToken(attributes); + if (!StringUtils.hasText(idToken)) { + return Mono.empty(); + } + if (!(principal instanceof OAuth2AuthenticationToken authenticationToken) + || authenticationToken.getClass() != OAuth2AuthenticationToken.class) { + // If the application customizes the authentication result, then a custom + // handler should be provided. + return Mono.empty(); + } + // The current principal must be an OidcUser. + if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) { + return Mono.empty(); + } + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + // The registrationId must match the one used to log in. + if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) { + return Mono.empty(); + } + // Create, validate OidcIdToken and refresh OidcUser in the SecurityContext. + return Mono.justOrEmpty((ServerWebExchange) attributes.get(ServerWebExchange.class.getName())) + .switchIfEmpty(currentServerWebExchangeMono) + .flatMap((exchange) -> { + ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + return jwtDecoder.decode(idToken).onErrorMap(JwtException.class, (ex) -> { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), + null); + return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + }) + .map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), + jwt.getClaims())) + .doOnNext((oidcIdToken) -> validateIdToken(existingOidcUser, oidcIdToken)) + .flatMap((oidcIdToken) -> { + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, + authorizedClient.getAccessToken(), oidcIdToken); + return this.userService.loadUser(userRequest); + }) + .flatMap((oidcUser) -> refreshSecurityContext(exchange, clientRegistration, authenticationToken, + oidcUser)); + }); + } + + /** + * Sets a {@link ServerSecurityContextRepository} to use for refreshing a + * {@link SecurityContext}, defaults to + * {@link WebSessionServerSecurityContextRepository}. + * @param serverSecurityContextRepository the {@link ServerSecurityContextRepository} + * to use + */ + public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) { + Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null"); + this.serverSecurityContextRepository = serverSecurityContextRepository; + } + + /** + * Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc + * id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}. + * @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use + */ + public void setJwtDecoderFactory(ReactiveJwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + + /** + * Sets a {@link GrantedAuthoritiesMapper} to use for mapping + * {@link GrantedAuthority}s, defaults to no-op implementation. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use + */ + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { + Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); + this.authoritiesMapper = authoritiesMapper; + } + + /** + * Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser} + * from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}. + * @param userService the {@link ReactiveOAuth2UserService} to use + */ + public void setUserService(ReactiveOAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OidcIdToken#getIssuedAt()} to match the existing + * {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds. + * @param clockSkew the maximum acceptable clock skew to use + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + private String extractIdToken(Map attributes) { + if (attributes.get(OidcParameterNames.ID_TOKEN) instanceof String idToken) { + return idToken; + } + return null; + } + + private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) { + // OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response + // If an ID Token is returned as a result of a token refresh request, the + // following requirements apply: + // its iss Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateIssuer(existingOidcUser, idToken); + // its sub Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateSubject(existingOidcUser, idToken); + // its iat Claim MUST represent the time that the new ID Token is issued, + validateIssuedAt(existingOidcUser, idToken); + // its aud Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateAudience(existingOidcUser, idToken); + // if the ID Token contains an auth_time Claim, its value MUST represent the time + // of the original authentication - not the time that the new ID token is issued, + validateAuthenticatedAt(existingOidcUser, idToken); + // it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of + // the original authentication contained nonce; however, if it is present, its + // value MUST be the same as in the ID Token issued at the time of the original + // authentication, + validateNonce(existingOidcUser, idToken); + } + + private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!isValidAudience(existingOidcUser, idToken)) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) { + List idTokenAudiences = idToken.getAudience(); + Set oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience()); + if (idTokenAudiences.size() != oidcUserAudiences.size()) { + return false; + } + for (String audience : idTokenAudiences) { + if (!oidcUserAudiences.contains(audience)) { + return false; + } + } + return true; + } + + private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) { + if (idToken.getAuthenticatedAt() == null) { + return; + } + if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!StringUtils.hasText(idToken.getNonce())) { + return; + } + if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private Mono refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration, + OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) { + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oidcUser.getAuthorities()); + OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities, + clientRegistration.getRegistrationId()); + authenticationResult.setDetails(authenticationToken.getDetails()); + SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult); + return this.serverSecurityContextRepository.save(exchange, securityContext); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java index 03e0bef1c4..4f9f86f85e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java @@ -21,7 +21,9 @@ import java.time.Duration; import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import reactor.core.publisher.Mono; @@ -40,6 +42,7 @@ import org.springframework.util.Assert; * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. * * @author Joe Grandja + * @author Evgeniy Cheban * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider * @see WebClientReactiveRefreshTokenTokenResponseClient @@ -49,6 +52,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -96,12 +101,17 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider .flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e)) - .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); - } - - private boolean hasTokenExpired(OAuth2Token token) { - return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + .flatMap((tokenResponse) -> { + OAuth2AuthorizedClient refreshedClient = new OAuth2AuthorizedClient(clientRegistration, + context.getPrincipal().getName(), tokenResponse.getAccessToken(), + tokenResponse.getRefreshToken()); + Map attributes = new HashMap<>(context.getAttributes()); + attributes.putAll(tokenResponse.getAdditionalParameters()); + return this.authorizationSuccessHandler + .onAuthorizationSuccess(refreshedClient, context.getPrincipal(), + Collections.unmodifiableMap(attributes)) + .thenReturn(refreshedClient); + }); } /** @@ -116,6 +126,19 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider this.accessTokenResponseClient = accessTokenResponseClient; } + /** + * Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} to use for handling + * successful refresh token response, defaults to + * {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}. + * @param authorizationSuccessHandler the + * {@link ReactiveOAuth2AuthorizationSuccessHandler} to use + * @since 7.1 + */ + public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null"); + this.authorizationSuccessHandler = authorizationSuccessHandler; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is @@ -143,4 +166,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider this.clock = clock; } + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index 9e4b2db717..e76e765d8a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -318,14 +318,14 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React return Mono.justOrEmpty(serverWebExchange) .switchIfEmpty(currentServerWebExchangeMono) .flatMap((exchange) -> { - Map contextAttributes = Collections.emptyMap(); + Map contextAttributes = new HashMap<>(); + contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope)) { - contextAttributes = new HashMap<>(); contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, StringUtils.delimitedListToStringArray(scope, " ")); } - return Mono.just(contextAttributes); + return Mono.just(Collections.unmodifiableMap(contextAttributes)); }) .defaultIfEmpty(Collections.emptyMap()); // @formatter:on diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index f4aa10d4b1..6a60fb977a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -51,6 +51,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.client.ClientRequest; @@ -96,6 +98,7 @@ import org.springframework.web.server.ServerWebExchange; * @author Rob Winch * @author Joe Grandja * @author Phil Clay + * @author Evgeniy Cheban * @since 5.1 */ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { @@ -139,6 +142,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private ClientResponseHandler clientResponseHandler; + private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + /** * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the * provided parameters. @@ -330,8 +335,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { - return next.exchange(request) - .transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono)); + // Re-request an Authentication from serverSecurityContextRepository since it + // might have been changed during provider invocation. + return effectiveAuthentication(request).flatMap((authentication) -> next.exchange(request) + .transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono)) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))); } private Mono authorizedClient(ClientRequest request) { @@ -362,6 +370,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements // @formatter:on } + private Mono effectiveAuthentication(ClientRequest request) { + // @formatter:off + return effectiveServerWebExchange(request) + .filter(Optional::isPresent) + .map(Optional::get) + .flatMap(this.serverSecurityContextRepository::load) + .mapNotNull(SecurityContext::getAuthentication) + .switchIfEmpty(this.currentAuthenticationMono); + // @formatter:on + } + /** * Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active * for the given request. @@ -445,6 +464,19 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } + /** + * Sets a {@link ServerSecurityContextRepository} to use for re-obtaining a + * {@link SecurityContext} if it has been refreshed during provider invocation, + * defaults to {@link WebSessionServerSecurityContextRepository}. + * @param serverSecurityContextRepository the {@link ServerSecurityContextRepository} + * to use + * @since 7.1 + */ + public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) { + Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null"); + this.serverSecurityContextRepository = serverSecurityContextRepository; + } + @FunctionalInterface private interface ClientResponseHandler { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandlerTests.java new file mode 100644 index 0000000000..5472886f4b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandlerTests.java @@ -0,0 +1,398 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThatException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler}. + * + * @author Evgeniy Cheban + */ +class RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandlerTests { + + @Test + void onAuthorizationSuccessWhenIdTokenValidThenSecurityContextRefreshed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyComplete(); + StepVerifier.create(serverSecurityContextRepository.load(exchange).map(SecurityContext::getAuthentication)) + .expectNext(authenticationToken) + .verifyComplete(); + } + + @Test + void onAuthorizationSuccessWhenIdTokenIssuerNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", "https://issuer.com"); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid issuer"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenSubNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", "invalid_sub"); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid subject"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenIatNotAfterThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt().minus(Duration.ofDays(1))); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid issued at time"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAudEmptyThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", Collections.emptyList()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid audience"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAudNotContainThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", List.of("invalid_client-id")); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid audience"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAuthTimeNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("auth_time", principal.getIssuedAt()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid authenticated at time"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenNonceNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken, null); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(ServerWebExchange.class.getName(), exchange, + OidcParameterNames.ID_TOKEN, "id-token-1234"); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", "invalid_nonce"); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_nonce] Invalid nonce"); + } + + @Test + void setServerSecurityContextRepositoryWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler() + .setServerSecurityContextRepository(null)) + .withMessage("serverSecurityContextRepository cannot be null"); + } + + @Test + void setJwtDecoderFactoryWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setJwtDecoderFactory(null)) + .withMessage("jwtDecoderFactory cannot be null"); + } + + @Test + void setAuthoritiesMapperWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setAuthoritiesMapper(null)) + .withMessage("authoritiesMapper cannot be null"); + } + + @Test + void setUserServiceWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setUserService(null)) + .withMessage("userService cannot be null"); + } + + @Test + void setClockSkewWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshOidcUserReactiveOAuth2AuthorizationSuccessHandler().setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + } + + private static OAuth2AccessToken createAccessToken() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt, + Set.of(OidcScopes.OPENID)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java index f30a4ac5d2..9550d88670 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java @@ -49,6 +49,7 @@ import static org.mockito.Mockito.verify; * Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. * * @author Joe Grandja + * @author Evgeniy Cheban */ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @@ -84,6 +85,15 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { .withMessage("accessTokenResponseClient cannot be null"); } + @Test + public void setAuthorizationSuccessHandlerWhenHandlerIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAuthorizationSuccessHandler(null)) + .withMessage("authorizationSuccessHandler cannot be null"); + // @formatter:on + } + @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { // @formatter:off diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java index 6811c4e344..c562572065 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java @@ -18,9 +18,13 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import reactor.core.publisher.Mono; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFunction; @@ -29,14 +33,21 @@ import static org.mockito.Mockito.mock; /** * @author Rob Winch + * @author Evgeniy Cheban * @since 5.1 */ public class MockExchangeFunction implements ExchangeFunction { + private final AtomicReference authenticationCaptor = new AtomicReference<>(); + private List requests = new ArrayList<>(); private ClientResponse response = mock(ClientResponse.class); + public Authentication getCapturedAuthentication() { + return this.authenticationCaptor.get(); + } + public ClientRequest getRequest() { return this.requests.get(this.requests.size() - 1); } @@ -53,8 +64,14 @@ public class MockExchangeFunction implements ExchangeFunction { public Mono exchange(ClientRequest request) { return Mono.defer(() -> { this.requests.add(request); - return Mono.just(this.response); + return captureAuthentication().then(Mono.just(this.response)); }); } + private Mono captureAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .doOnNext(this.authenticationCaptor::set); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java index dfe7613912..de86312619 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -19,8 +19,15 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import com.nimbusds.jose.jwk.source.ImmutableJWKSet; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.jupiter.api.AfterEach; @@ -38,23 +45,42 @@ import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JwsHeader; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.JwtEncoderParameters; +import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.server.ServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; @@ -67,9 +93,12 @@ import static org.mockito.Mockito.verify; /** * @author Phil Clay + * @author Evgeniy Cheban */ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { + private final AuthenticationCapturingExchangeFilterFunction authenticationCapturingFilter = new AuthenticationCapturingExchangeFilterFunction(); + private ReactiveClientRegistrationRepository clientRegistrationRepository; private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; @@ -118,6 +147,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { // @formatter:off this.webClient = WebClient.builder() .filter(this.authorizedClientFilter) + .filter(this.authenticationCapturingFilter) .build(); // @formatter:on this.authentication = new TestingAuthenticationToken("principal", "password"); @@ -218,6 +248,116 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); } + @Test + public void requestWhenAuthorizedButExpiredThenRefreshSecurityContext() throws Exception { + // Current OIDC user in the SecurityContext. + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .issuerUri(this.serverUrl) + .tokenUri(this.serverUrl) + .jwkSetUri(this.serverUrl + "oauth2/jwk") + .userInfoUri(this.serverUrl + "user") + .userNameAttributeName("username") + .build(); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); + OidcIdToken existingIdToken = OidcIdToken.withTokenValue("id-token-1234") + .issuer(this.serverUrl) + .subject("subject-1234") + .audience(List.of(clientRegistration.getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + OidcUser existingOidcUser = new DefaultOidcUser(Collections.emptyList(), existingIdToken); + this.authentication = new OAuth2AuthenticationToken(existingOidcUser, Collections.emptyList(), + clientRegistration.getRegistrationId()); + // Generate new OIDC ID token with refreshed user information. + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).type("JWT").build(); + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder() + .subject("subject-1234") + .audience(List.of(clientRegistration.getClientId())) + .issuer(this.serverUrl) + .issuedAt(Instant.now()) + .expiresAt(Instant.now().plusSeconds(3600)) + .build(); + RSAKey key = new RSAKeyGenerator(2048).generate(); + JWKSet jwkSet = new JWKSet(key); + JwtEncoder jwtEncoder = new NimbusJwtEncoder(new ImmutableJWKSet<>(jwkSet)); + String idToken = jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)).getTokenValue(); + // @formatter:off + String accessTokenResponse = """ + { + "access_token": "refreshed-access-token", + "id_token": "%s", + "token_type": "bearer", + "expires_in": "3600" + } + """.formatted(idToken); + String userInfoResponse = """ + { + "sub": "subject-1234", + "username": "refreshed-username" + } + """; + String clientResponse = """ + { + "attribute1": "value1", + "attribute2": "value2" + } + """; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(jwkSet.toString())); + this.server.enqueue(jsonResponse(userInfoResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "expired-access-token", issuedAt, expiresAt, + new HashSet<>(Arrays.asList("read", "write", OidcScopes.OPENID))); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + this.authentication.getName(), accessToken, refreshToken); + doReturn(Mono.just(authorizedClient)).when(this.authorizedClientRepository) + .loadAuthorizedClient(eq(clientRegistration.getRegistrationId()), eq(this.authentication), + eq(this.exchange)); + this.webClient.get() + .uri(this.serverUrl) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .contextWrite(Context.of(ServerWebExchange.class, this.exchange)) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + assertThat(this.server.getRequestCount()).isEqualTo(4); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.exchange)); + OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); + assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); + WebSessionServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository(); + // Capture and verify that the refreshed Authentication object was propagated to + // the next ExchangeFilterFunction's context. + Authentication capturedAuthentication = this.authenticationCapturingFilter.getCapturedAuthentication(); + assertThat(capturedAuthentication).isNotNull(); + Authentication refreshedAuthentication = securityContextRepository.load(this.exchange) + .mapNotNull(SecurityContext::getAuthentication) + .block(); + assertThat(refreshedAuthentication).isNotNull(); + assertThat(refreshedAuthentication).isSameAs(capturedAuthentication); + assertThat(refreshedAuthentication).asInstanceOf(type(OAuth2AuthenticationToken.class)) + .extracting(OAuth2AuthenticationToken::getPrincipal) + .asInstanceOf(type(OidcUser.class)) + .satisfies((oidcUser) -> { + // Verify that the OidcUser's attributes match the id_token's claims. + assertThat(oidcUser.getIdToken()).extracting(AbstractOAuth2Token::getTokenValue).isEqualTo(idToken); + assertThat(oidcUser.getSubject()).isEqualTo("subject-1234"); + assertThat(oidcUser.getName()).isEqualTo("refreshed-username"); + }); + } + @Test public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { // @formatter:off @@ -349,4 +489,22 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { // @formatter:on } + private static final class AuthenticationCapturingExchangeFilterFunction implements ExchangeFilterFunction { + + private final AtomicReference authenticationCaptor = new AtomicReference<>(); + + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + return ReactiveSecurityContextHolder.getContext().flatMap((ctx) -> { + this.authenticationCaptor.set(ctx.getAuthentication()); + return next.exchange(request); + }); + } + + private Authentication getCapturedAuthentication() { + return this.authenticationCaptor.get(); + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index b78a8258f6..e2046e4af1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -59,6 +59,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; @@ -66,6 +67,7 @@ import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2Authoriz import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; @@ -89,8 +91,10 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.core.user.TestOAuth2Users; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; @@ -113,6 +117,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; /** * @author Rob Winch + * @author Evgeniy Cheban * @since 5.1 */ @ExtendWith(MockitoExtension.class) @@ -133,6 +138,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ReactiveOAuth2AccessTokenResponseClient jwtBearerTokenResponseClient; + @Mock + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + @Mock private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; @@ -170,7 +178,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .builder() .authorizationCode() .refreshToken( - (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + (configurer) -> configurer + .authorizationSuccessHandler(this.authorizationSuccessHandler) + .accessTokenResponseClient(this.refreshTokenTokenResponseClient)) .clientCredentials( (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) .provider(jwtBearerAuthorizedClientProvider) @@ -201,6 +211,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null)); } + @Test + public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager) + .setServerSecurityContextRepository(null)); + } + @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); @@ -312,6 +329,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .refreshToken("refresh-1") .build(); given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); + given(this.authorizationSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty()); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), @@ -326,14 +344,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { // @formatter:on TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); // @formatter:off + DefaultOAuth2User refreshedUser = TestOAuth2Users.create(); + OAuth2AuthenticationToken refreshedAuthentication = new OAuth2AuthenticationToken(refreshedUser, refreshedUser.getAuthorities(), this.registration.getRegistrationId()); + SecurityContextImpl securityContext = new SecurityContextImpl(refreshedAuthentication); + ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class); + given(securityContextRepository.load(this.serverWebExchange)).willReturn(Mono.just(securityContext)); + this.function.setServerSecurityContextRepository(securityContextRepository); this.function.filter(request, this.exchange) .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)) .contextWrite(serverWebExchange()) .block(); + Authentication currentAuthentication = this.exchange.getCapturedAuthentication(); + assertThat(currentAuthentication).isSameAs(refreshedAuthentication); // @formatter:on verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any()); + verify(securityContextRepository).load(this.serverWebExchange); OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); @@ -355,6 +382,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .refreshToken("refresh-1") .build(); given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); + given(this.authorizationSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty()); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(),