Allow InMemoryOAuth2AuthorizedClientService to be constructed with a Map
Fixes gh-5994
This commit is contained in:
parent
d66d895e60
commit
9432670f1d
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,7 +20,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
|
|||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
@ -29,6 +28,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @author Vedran Pavic
|
||||
* @since 5.0
|
||||
* @see OAuth2AuthorizedClientService
|
||||
* @see OAuth2AuthorizedClient
|
||||
|
@ -36,8 +36,8 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||
* @see Authentication
|
||||
*/
|
||||
public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
|
||||
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
|
||||
private final ClientRegistrationRepository clientRegistrationRepository;
|
||||
private Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters.
|
||||
|
@ -49,7 +49,17 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
|
|||
this.clientRegistrationRepository = clientRegistrationRepository;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the map of authorized clients to use.
|
||||
* @param authorizedClients the map of authorized clients
|
||||
*/
|
||||
public void setAuthorizedClients(Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients) {
|
||||
Assert.notNull(authorizedClients, "authorizedClients cannot be null");
|
||||
this.authorizedClients = authorizedClients;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
|
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
|
||||
Assert.hasText(principalName, "principalName cannot be empty");
|
||||
|
@ -57,15 +67,15 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
|
|||
if (registration == null) {
|
||||
return null;
|
||||
}
|
||||
return (T) this.authorizedClients.get(this.getIdentifier(registration, principalName));
|
||||
return (T) this.authorizedClients.get(OAuth2AuthorizedClientId.create(registration, principalName));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
|
||||
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
|
||||
Assert.notNull(principal, "principal cannot be null");
|
||||
this.authorizedClients.put(this.getIdentifier(
|
||||
authorizedClient.getClientRegistration(), principal.getName()), authorizedClient);
|
||||
this.authorizedClients.put(OAuth2AuthorizedClientId.create(authorizedClient.getClientRegistration(),
|
||||
principal.getName()), authorizedClient);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,12 +84,8 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
|
|||
Assert.hasText(principalName, "principalName cannot be empty");
|
||||
ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
||||
if (registration != null) {
|
||||
this.authorizedClients.remove(this.getIdentifier(registration, principalName));
|
||||
this.authorizedClients.remove(OAuth2AuthorizedClientId.create(registration, principalName));
|
||||
}
|
||||
}
|
||||
|
||||
private String getIdentifier(ClientRegistration registration, String principalName) {
|
||||
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
|
||||
return Base64.getEncoder().encodeToString(identifier.getBytes());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
package org.springframework.security.oauth2.client;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
@ -31,6 +30,7 @@ import reactor.core.publisher.Mono;
|
|||
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @author Vedran Pavic
|
||||
* @since 5.1
|
||||
* @see OAuth2AuthorizedClientService
|
||||
* @see OAuth2AuthorizedClient
|
||||
|
@ -38,7 +38,7 @@ import reactor.core.publisher.Mono;
|
|||
* @see Authentication
|
||||
*/
|
||||
public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
|
||||
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
|
||||
private final Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();;
|
||||
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||
|
||||
/**
|
||||
|
@ -52,10 +52,12 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
|
|||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String clientRegistrationId, String principalName) {
|
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
|
||||
Assert.hasText(principalName, "principalName cannot be empty");
|
||||
return (Mono<T>) getIdentifier(clientRegistrationId, principalName)
|
||||
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
|
||||
.flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
|
||||
}
|
||||
|
||||
|
@ -64,7 +66,8 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
|
|||
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
|
||||
Assert.notNull(principal, "principal cannot be null");
|
||||
return Mono.fromRunnable(() -> {
|
||||
String identifier = this.getIdentifier(authorizedClient.getClientRegistration(), principal.getName());
|
||||
OAuth2AuthorizedClientId identifier = OAuth2AuthorizedClientId.create(
|
||||
authorizedClient.getClientRegistration(), principal.getName());
|
||||
this.authorizedClients.put(identifier, authorizedClient);
|
||||
});
|
||||
}
|
||||
|
@ -73,18 +76,10 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
|
|||
public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) {
|
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
|
||||
Assert.hasText(principalName, "principalName cannot be empty");
|
||||
return this.getIdentifier(clientRegistrationId, principalName)
|
||||
.doOnNext(identifier -> this.authorizedClients.remove(identifier))
|
||||
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
|
||||
.doOnNext(this.authorizedClients::remove)
|
||||
.then(Mono.empty());
|
||||
}
|
||||
|
||||
private Mono<String> getIdentifier(String clientRegistrationId, String principalName) {
|
||||
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||
.map(registration -> getIdentifier(registration, principalName));
|
||||
}
|
||||
|
||||
private String getIdentifier(ClientRegistration registration, String principalName) {
|
||||
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
|
||||
return Base64.getEncoder().encodeToString(identifier.getBytes());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* The identifier for {@link OAuth2AuthorizedClient}.
|
||||
*
|
||||
* @author Vedran Pavic
|
||||
* @since 5.2
|
||||
* @see OAuth2AuthorizedClient
|
||||
* @see OAuth2AuthorizedClientService
|
||||
*/
|
||||
public final class OAuth2AuthorizedClientId implements Serializable {
|
||||
|
||||
private final String clientRegistrationId;
|
||||
|
||||
private final String principalName;
|
||||
|
||||
private OAuth2AuthorizedClientId(String clientRegistrationId, String principalName) {
|
||||
Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null");
|
||||
Assert.notNull(principalName, "principalName cannot be null");
|
||||
this.clientRegistrationId = clientRegistrationId;
|
||||
this.principalName = principalName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory method for creating new {@link OAuth2AuthorizedClientId} using
|
||||
* {@link ClientRegistration} and principal name.
|
||||
* @param clientRegistration the client registration
|
||||
* @param principalName the principal name
|
||||
* @return the new authorized client id
|
||||
*/
|
||||
public static OAuth2AuthorizedClientId create(ClientRegistration clientRegistration,
|
||||
String principalName) {
|
||||
return new OAuth2AuthorizedClientId(clientRegistration.getRegistrationId(),
|
||||
principalName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
OAuth2AuthorizedClientId that = (OAuth2AuthorizedClientId) obj;
|
||||
return Objects.equals(this.clientRegistrationId, that.clientRegistrationId)
|
||||
&& Objects.equals(this.principalName, that.principalName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(this.clientRegistrationId, this.principalName);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,7 +15,11 @@
|
|||
*/
|
||||
package org.springframework.security.oauth2.client;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||
|
@ -24,6 +28,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
|
|||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -31,6 +38,7 @@ import static org.mockito.Mockito.when;
|
|||
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @author Vedran Pavic
|
||||
*/
|
||||
public class InMemoryOAuth2AuthorizedClientServiceTests {
|
||||
private String principalName1 = "principal-1";
|
||||
|
@ -57,6 +65,30 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
|
|||
new InMemoryOAuth2AuthorizedClientService(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAuthorizedClientsIsNullThenIllegalArgumentException() {
|
||||
assertThatExceptionOfType(IllegalArgumentException.class)
|
||||
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClients(null))
|
||||
.withMessage("authorizedClients cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAuthorizedClientsIsEmptyMapThenRepositoryUsingSuppliedAuthorizedClients() {
|
||||
String registrationId = this.registration3.getRegistrationId();
|
||||
|
||||
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
|
||||
OAuth2AuthorizedClientId.create(this.registration3, this.principalName1),
|
||||
mock(OAuth2AuthorizedClient.class));
|
||||
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
|
||||
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
|
||||
|
||||
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
|
||||
this.clientRegistrationRepository);
|
||||
authorizedClientService.setAuthorizedClients(authorizedClients);
|
||||
assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient(
|
||||
registrationId, this.principalName1)).isNotNull();
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
|
||||
this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2AuthorizedClientId}.
|
||||
*
|
||||
* @author Vedran Pavic
|
||||
*/
|
||||
public class OAuth2AuthorizedClientIdTests {
|
||||
|
||||
@Test
|
||||
public void equalsWhenSameRegistrationIdAndPrincipalThenShouldReturnTrue() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal");
|
||||
assertThat(id1.equals(id2)).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void equalsWhenDifferentRegistrationIdAndSamePrincipalThenShouldReturnFalse() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
|
||||
"test-principal");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
|
||||
"test-principal");
|
||||
assertThat(id1.equals(id2)).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void equalsWhenSameRegistrationIdAndDifferentPrincipalThenShouldReturnFalse() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal1");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal2");
|
||||
assertThat(id1.equals(id2)).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hashCodeWhenSameRegistrationIdAndPrincipalThenShouldReturnSame() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal");
|
||||
assertThat(id1.hashCode()).isEqualTo(id2.hashCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hashCodeWhenDifferentRegistrationIdAndSamePrincipalThenShouldNotReturnSame() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
|
||||
"test-principal");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
|
||||
"test-principal");
|
||||
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hashCodeWhenSameRegistrationIdAndDifferentPrincipalThenShouldNotReturnSame() {
|
||||
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal1");
|
||||
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
|
||||
"test-principal2");
|
||||
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
|
||||
}
|
||||
|
||||
private static ClientRegistration testClientRegistration(String registrationId) {
|
||||
return ClientRegistration.withRegistrationId(registrationId).clientId("id").clientSecret("secret")
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
|
||||
.authorizationUri("http://example.com/authorize").tokenUri("http://example.com/token").build();
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue