diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java index 3041ce764f..8ec6fdb17e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,7 +80,13 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author if (registration == null) { return null; } - return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); + OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients + .get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); + if (cachedAuthorizedClient == null) { + return null; + } + return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(), + cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken()); } @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java index 3cf977d477..a096d04d0f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,8 +62,15 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); return (Mono) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) - .flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); + .mapNotNull((clientRegistration) -> { + OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName); + OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id); + if (cachedAuthorizedClient == null) { + return null; + } + return new OAuth2AuthorizedClient(clientRegistration, cachedAuthorizedClient.getPrincipalName(), + cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken()); + }); } @Override diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java index efa546b5d0..8df5cca36d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,22 +18,25 @@ package org.springframework.security.oauth2.client; import java.util.Collections; import java.util.Map; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatObject; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; +import static org.mockito.BDDMockito.mock; /** * Tests for {@link InMemoryOAuth2AuthorizedClientService}. @@ -79,9 +82,11 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { @Test public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() { String registrationId = this.registration3.getRegistrationId(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1, + mock(OAuth2AccessToken.class)); Map authorizedClients = Collections.singletonMap( new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1), - mock(OAuth2AuthorizedClient.class)); + authorizedClient); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3); InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( @@ -124,7 +129,35 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); - assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient)); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() { + ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1) + .clientSecret("updated secret") + .build(); + + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(this.registration1.getRegistrationId())) + .willReturn(this.registration1, updatedRegistration); + + InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( + clientRegistrationRepository); + + OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.registration1, + this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class)); + authorizedClientService.saveAuthorizedClient(cachedAuthorizedClient, + new TestingAuthenticationToken(this.principalName1, null)); + + OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration, + this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class)); + OAuth2AuthorizedClient firstLoadedClient = authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); + OAuth2AuthorizedClient secondLoadedClient = authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); + assertThat(firstLoadedClient).satisfies(isEqualTo(cachedAuthorizedClient)); + assertThat(secondLoadedClient).satisfies(isEqualTo(authorizedClientWithUpdatedRegistration)); } @Test @@ -148,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService .loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2); - assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient)); } @Test @@ -180,4 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { assertThat(loadedAuthorizedClient).isNull(); } + private static Consumer isEqualTo(OAuth2AuthorizedClient expected) { + return (actual) -> { + assertThat(actual).isNotNull(); + assertThat(actual.getClientRegistration().getRegistrationId()) + .isEqualTo(expected.getClientRegistration().getRegistrationId()); + assertThat(actual.getClientRegistration().getClientName()) + .isEqualTo(expected.getClientRegistration().getClientName()); + assertThat(actual.getClientRegistration().getRedirectUri()) + .isEqualTo(expected.getClientRegistration().getRedirectUri()); + assertThat(actual.getClientRegistration().getAuthorizationGrantType()) + .isEqualTo(expected.getClientRegistration().getAuthorizationGrantType()); + assertThat(actual.getClientRegistration().getClientAuthenticationMethod()) + .isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod()); + assertThat(actual.getClientRegistration().getClientId()) + .isEqualTo(expected.getClientRegistration().getClientId()); + assertThat(actual.getClientRegistration().getClientSecret()) + .isEqualTo(expected.getClientRegistration().getClientSecret()); + assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + if (expected.getRefreshToken() != null) { + assertThat(actual.getRefreshToken()).isNotNull(); + assertThat(actual.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); + assertThat(actual.getRefreshToken().getExpiresAt()) + .isEqualTo(expected.getRefreshToken().getExpiresAt()); + } + }; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java index 71a359b5ab..cc4e76700a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,14 @@ package org.springframework.security.oauth2.client; import java.time.Duration; import java.time.Instant; +import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -34,7 +36,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; @@ -56,8 +60,9 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused"); - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), - Instant.now().plus(Duration.ofDays(1))); + private OAuth2AccessToken accessToken; + + private OAuth2RefreshToken refreshToken; // @formatter:off private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) @@ -79,6 +84,11 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { public void setup() { this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService( this.clientRegistrationRepository); + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(Duration.ofDays(1)); + this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt); + this.refreshToken = new OAuth2RefreshToken("refresh", issuedAt, expiresAt); } @Test @@ -153,11 +163,37 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { .saveAuthorizedClient(authorizedClient, this.principal) .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); StepVerifier.create(saveAndLoad) - .expectNext(authorizedClient) + .assertNext(isEqualTo(authorizedClient)) .verifyComplete(); // @formatter:on } + @Test + @SuppressWarnings("unchecked") + public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnsAuthorizedClientWithUpdatedClientRegistration() { + ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.clientRegistration) + .clientSecret("updated secret") + .build(); + + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.just(this.clientRegistration), Mono.just(updatedRegistration)); + + OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken, this.refreshToken); + OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(updatedRegistration, + this.principalName, this.accessToken, this.refreshToken); + + Flux saveAndLoadTwice = this.authorizedClientService + .saveAuthorizedClient(cachedAuthorizedClient, this.principal) + .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) + .concatWith( + this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + StepVerifier.create(saveAndLoadTwice) + .assertNext(isEqualTo(cachedAuthorizedClient)) + .assertNext(isEqualTo(authorizedClientWithChangedRegistration)) + .verifyComplete(); + } + @Test public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() { OAuth2AuthorizedClient authorizedClient = null; @@ -246,4 +282,38 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { // @formatter:on } + private static Consumer isEqualTo(OAuth2AuthorizedClient expected) { + return (actual) -> { + assertThat(actual).isNotNull(); + assertThat(actual.getClientRegistration().getRegistrationId()) + .isEqualTo(expected.getClientRegistration().getRegistrationId()); + assertThat(actual.getClientRegistration().getClientName()) + .isEqualTo(expected.getClientRegistration().getClientName()); + assertThat(actual.getClientRegistration().getRedirectUri()) + .isEqualTo(expected.getClientRegistration().getRedirectUri()); + assertThat(actual.getClientRegistration().getAuthorizationGrantType()) + .isEqualTo(expected.getClientRegistration().getAuthorizationGrantType()); + assertThat(actual.getClientRegistration().getClientAuthenticationMethod()) + .isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod()); + assertThat(actual.getClientRegistration().getClientId()) + .isEqualTo(expected.getClientRegistration().getClientId()); + assertThat(actual.getClientRegistration().getClientSecret()) + .isEqualTo(expected.getClientRegistration().getClientSecret()); + assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + if (expected.getRefreshToken() != null) { + assertThat(actual.getRefreshToken()).isNotNull(); + assertThat(actual.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); + assertThat(actual.getRefreshToken().getExpiresAt()) + .isEqualTo(expected.getRefreshToken().getExpiresAt()); + } + }; + } + }