Merge branch '6.2.x' into 6.3.x

Closes gh-16138
This commit is contained in:
Steve Riesenberg 2024-11-20 15:54:29 -06:00
commit 4b41f8cb5b
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
4 changed files with 164 additions and 14 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -80,7 +80,13 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
if (registration == null) { if (registration == null) {
return null; return null;
} }
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
if (cachedAuthorizedClient == null) {
return null;
}
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
} }
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -62,8 +62,15 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty");
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) .mapNotNull((clientRegistration) -> {
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
if (cachedAuthorizedClient == null) {
return null;
}
return new OAuth2AuthorizedClient(clientRegistration, cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
});
} }
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,22 +18,25 @@ package org.springframework.security.oauth2.client;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatObject; import static org.assertj.core.api.Assertions.assertThatObject;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.BDDMockito.mock;
/** /**
* Tests for {@link InMemoryOAuth2AuthorizedClientService}. * Tests for {@link InMemoryOAuth2AuthorizedClientService}.
@ -79,9 +82,11 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@Test @Test
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() { public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
String registrationId = this.registration3.getRegistrationId(); String registrationId = this.registration3.getRegistrationId();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
mock(OAuth2AccessToken.class));
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap( Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1), new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
mock(OAuth2AuthorizedClient.class)); authorizedClient);
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3); given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
@ -124,7 +129,35 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
}
@Test
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
.clientSecret("updated secret")
.build();
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(this.registration1.getRegistrationId()))
.willReturn(this.registration1, updatedRegistration);
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
clientRegistrationRepository);
OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.registration1,
this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
authorizedClientService.saveAuthorizedClient(cachedAuthorizedClient,
new TestingAuthenticationToken(this.principalName1, null));
OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
OAuth2AuthorizedClient firstLoadedClient = authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
OAuth2AuthorizedClient secondLoadedClient = authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(firstLoadedClient).satisfies(isEqualTo(cachedAuthorizedClient));
assertThat(secondLoadedClient).satisfies(isEqualTo(authorizedClientWithUpdatedRegistration));
} }
@Test @Test
@ -148,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2); .loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
} }
@Test @Test
@ -180,4 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
assertThat(loadedAuthorizedClient).isNull(); assertThat(loadedAuthorizedClient).isNull();
} }
private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
return (actual) -> {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
if (expected.getRefreshToken() != null) {
assertThat(actual.getRefreshToken()).isNotNull();
assertThat(actual.getRefreshToken().getTokenValue())
.isEqualTo(expected.getRefreshToken().getTokenValue());
assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
assertThat(actual.getRefreshToken().getExpiresAt())
.isEqualTo(expected.getRefreshToken().getExpiresAt());
}
};
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,12 +18,14 @@ package org.springframework.security.oauth2.client;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
@ -34,7 +36,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -56,8 +60,9 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused"); private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused");
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), private OAuth2AccessToken accessToken;
Instant.now().plus(Duration.ofDays(1)));
private OAuth2RefreshToken refreshToken;
// @formatter:off // @formatter:off
private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
@ -79,6 +84,11 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
public void setup() { public void setup() {
this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService( this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(
this.clientRegistrationRepository); this.clientRegistrationRepository);
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(Duration.ofDays(1));
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt);
this.refreshToken = new OAuth2RefreshToken("refresh", issuedAt, expiresAt);
} }
@Test @Test
@ -153,11 +163,37 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
.saveAuthorizedClient(authorizedClient, this.principal) .saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoad) StepVerifier.create(saveAndLoad)
.expectNext(authorizedClient) .assertNext(isEqualTo(authorizedClient))
.verifyComplete(); .verifyComplete();
// @formatter:on // @formatter:on
} }
@Test
@SuppressWarnings("unchecked")
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnsAuthorizedClientWithUpdatedClientRegistration() {
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.clientRegistration)
.clientSecret("updated secret")
.build();
given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
.willReturn(Mono.just(this.clientRegistration), Mono.just(updatedRegistration));
OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken, this.refreshToken);
OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
this.principalName, this.accessToken, this.refreshToken);
Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
.saveAuthorizedClient(cachedAuthorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
.concatWith(
this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoadTwice)
.assertNext(isEqualTo(cachedAuthorizedClient))
.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
.verifyComplete();
}
@Test @Test
public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() { public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() {
OAuth2AuthorizedClient authorizedClient = null; OAuth2AuthorizedClient authorizedClient = null;
@ -246,4 +282,38 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
// @formatter:on // @formatter:on
} }
private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
return (actual) -> {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
if (expected.getRefreshToken() != null) {
assertThat(actual.getRefreshToken()).isNotNull();
assertThat(actual.getRefreshToken().getTokenValue())
.isEqualTo(expected.getRefreshToken().getTokenValue());
assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
assertThat(actual.getRefreshToken().getExpiresAt())
.isEqualTo(expected.getRefreshToken().getExpiresAt());
}
};
}
} }