diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java index 181753bed4..77718a15c7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.util.Assert; -import java.util.Base64; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -29,6 +28,7 @@ import java.util.concurrent.ConcurrentHashMap; * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory. * * @author Joe Grandja + * @author Vedran Pavic * @since 5.0 * @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClient @@ -36,8 +36,8 @@ import java.util.concurrent.ConcurrentHashMap; * @see Authentication */ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { - private final Map authorizedClients = new ConcurrentHashMap<>(); private final ClientRegistrationRepository clientRegistrationRepository; + private Map authorizedClients = new ConcurrentHashMap<>(); /** * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters. @@ -49,7 +49,17 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author this.clientRegistrationRepository = clientRegistrationRepository; } + /** + * Sets the map of authorized clients to use. + * @param authorizedClients the map of authorized clients + */ + public void setAuthorizedClients(Map authorizedClients) { + Assert.notNull(authorizedClients, "authorizedClients cannot be null"); + this.authorizedClients = authorizedClients; + } + @Override + @SuppressWarnings("unchecked") public T loadAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); @@ -57,15 +67,15 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author if (registration == null) { return null; } - return (T) this.authorizedClients.get(this.getIdentifier(registration, principalName)); + return (T) this.authorizedClients.get(OAuth2AuthorizedClientId.create(registration, principalName)); } @Override public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(principal, "principal cannot be null"); - this.authorizedClients.put(this.getIdentifier( - authorizedClient.getClientRegistration(), principal.getName()), authorizedClient); + this.authorizedClients.put(OAuth2AuthorizedClientId.create(authorizedClient.getClientRegistration(), + principal.getName()), authorizedClient); } @Override @@ -74,12 +84,8 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author Assert.hasText(principalName, "principalName cannot be empty"); ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); if (registration != null) { - this.authorizedClients.remove(this.getIdentifier(registration, principalName)); + this.authorizedClients.remove(OAuth2AuthorizedClientId.create(registration, principalName)); } } - private String getIdentifier(ClientRegistration registration, String principalName) { - String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]"; - return Base64.getEncoder().encodeToString(identifier.getBytes()); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java index 21b08393a3..64db840ccd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java @@ -15,7 +15,6 @@ */ package org.springframework.security.oauth2.client; -import java.util.Base64; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -31,6 +30,7 @@ import reactor.core.publisher.Mono; * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory. * * @author Rob Winch + * @author Vedran Pavic * @since 5.1 * @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClient @@ -38,7 +38,7 @@ import reactor.core.publisher.Mono; * @see Authentication */ public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService { - private final Map authorizedClients = new ConcurrentHashMap<>(); + private final Map authorizedClients = new ConcurrentHashMap<>();; private final ReactiveClientRegistrationRepository clientRegistrationRepository; /** @@ -52,10 +52,12 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac } @Override + @SuppressWarnings("unchecked") public Mono loadAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - return (Mono) getIdentifier(clientRegistrationId, principalName) + return (Mono) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName)) .flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); } @@ -64,7 +66,8 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(principal, "principal cannot be null"); return Mono.fromRunnable(() -> { - String identifier = this.getIdentifier(authorizedClient.getClientRegistration(), principal.getName()); + OAuth2AuthorizedClientId identifier = OAuth2AuthorizedClientId.create( + authorizedClient.getClientRegistration(), principal.getName()); this.authorizedClients.put(identifier, authorizedClient); }); } @@ -73,18 +76,10 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac public Mono removeAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - return this.getIdentifier(clientRegistrationId, principalName) - .doOnNext(identifier -> this.authorizedClients.remove(identifier)) + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName)) + .doOnNext(this.authorizedClients::remove) .then(Mono.empty()); } - private Mono getIdentifier(String clientRegistrationId, String principalName) { - return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .map(registration -> getIdentifier(registration, principalName)); - } - - private String getIdentifier(ClientRegistration registration, String principalName) { - String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]"; - return Base64.getEncoder().encodeToString(identifier.getBytes()); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java new file mode 100644 index 0000000000..2501b1f22f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.io.Serializable; +import java.util.Objects; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; + +/** + * The identifier for {@link OAuth2AuthorizedClient}. + * + * @author Vedran Pavic + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientService + */ +public final class OAuth2AuthorizedClientId implements Serializable { + + private final String clientRegistrationId; + + private final String principalName; + + private OAuth2AuthorizedClientId(String clientRegistrationId, String principalName) { + Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null"); + Assert.notNull(principalName, "principalName cannot be null"); + this.clientRegistrationId = clientRegistrationId; + this.principalName = principalName; + } + + /** + * Factory method for creating new {@link OAuth2AuthorizedClientId} using + * {@link ClientRegistration} and principal name. + * @param clientRegistration the client registration + * @param principalName the principal name + * @return the new authorized client id + */ + public static OAuth2AuthorizedClientId create(ClientRegistration clientRegistration, + String principalName) { + return new OAuth2AuthorizedClientId(clientRegistration.getRegistrationId(), + principalName); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2AuthorizedClientId that = (OAuth2AuthorizedClientId) obj; + return Objects.equals(this.clientRegistrationId, that.clientRegistrationId) + && Objects.equals(this.principalName, that.principalName); + } + + @Override + public int hashCode() { + return Objects.hash(this.clientRegistrationId, this.principalName); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java index 31508826a0..e6482a64a8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,11 @@ */ package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.Map; + import org.junit.Test; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -24,6 +28,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.OAuth2AccessToken; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -31,6 +38,7 @@ import static org.mockito.Mockito.when; * Tests for {@link InMemoryOAuth2AuthorizedClientService}. * * @author Joe Grandja + * @author Vedran Pavic */ public class InMemoryOAuth2AuthorizedClientServiceTests { private String principalName1 = "principal-1"; @@ -57,6 +65,30 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { new InMemoryOAuth2AuthorizedClientService(null); } + @Test + public void constructorWhenAuthorizedClientsIsNullThenIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.setAuthorizedClients(null)) + .withMessage("authorizedClients cannot be null"); + } + + @Test + public void constructorWhenAuthorizedClientsIsEmptyMapThenRepositoryUsingSuppliedAuthorizedClients() { + String registrationId = this.registration3.getRegistrationId(); + + Map authorizedClients = Collections.singletonMap( + OAuth2AuthorizedClientId.create(this.registration3, this.principalName1), + mock(OAuth2AuthorizedClient.class)); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3); + + InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( + this.clientRegistrationRepository); + authorizedClientService.setAuthorizedClients(authorizedClients); + assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient( + registrationId, this.principalName1)).isNotNull(); + } + @Test(expected = IllegalArgumentException.class) public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { this.authorizedClientService.loadAuthorizedClient(null, this.principalName1); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java new file mode 100644 index 0000000000..b979f42061 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import org.junit.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OAuth2AuthorizedClientId}. + * + * @author Vedran Pavic + */ +public class OAuth2AuthorizedClientIdTests { + + @Test + public void equalsWhenSameRegistrationIdAndPrincipalThenShouldReturnTrue() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal"); + assertThat(id1.equals(id2)).isTrue(); + } + + @Test + public void equalsWhenDifferentRegistrationIdAndSamePrincipalThenShouldReturnFalse() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"), + "test-principal"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"), + "test-principal"); + assertThat(id1.equals(id2)).isFalse(); + } + + @Test + public void equalsWhenSameRegistrationIdAndDifferentPrincipalThenShouldReturnFalse() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal1"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal2"); + assertThat(id1.equals(id2)).isFalse(); + } + + @Test + public void hashCodeWhenSameRegistrationIdAndPrincipalThenShouldReturnSame() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal"); + assertThat(id1.hashCode()).isEqualTo(id2.hashCode()); + } + + @Test + public void hashCodeWhenDifferentRegistrationIdAndSamePrincipalThenShouldNotReturnSame() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"), + "test-principal"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"), + "test-principal"); + assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode()); + } + + @Test + public void hashCodeWhenSameRegistrationIdAndDifferentPrincipalThenShouldNotReturnSame() { + OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal1"); + OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"), + "test-principal2"); + assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode()); + } + + private static ClientRegistration testClientRegistration(String registrationId) { + return ClientRegistration.withRegistrationId(registrationId).clientId("id").clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}") + .authorizationUri("http://example.com/authorize").tokenUri("http://example.com/token").build(); + } + +}