diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index d43a0849e1..6e17822390 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -12,6 +12,7 @@ dependencies { optional 'org.springframework:spring-webflux' optional 'com.fasterxml.jackson.core:jackson-databind' optional 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' + optional 'org.springframework:spring-jdbc' testCompile project(path: ':spring-security-oauth2-core', configuration: 'tests') testCompile project(path: ':spring-security-oauth2-jose', configuration: 'tests') @@ -22,5 +23,7 @@ dependencies { testCompile 'io.projectreactor.tools:blockhound' testCompile 'org.skyscreamer:jsonassert' + testRuntime 'org.hsqldb:hsqldb' + provided 'javax.servlet:javax.servlet-api' } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java new file mode 100644 index 0000000000..a4ac907123 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java @@ -0,0 +1,312 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.nio.charset.StandardCharsets; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +/** + * A JDBC implementation of an {@link OAuth2AuthorizedClientService} + * that uses a {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence. + * + *

+ * NOTE: This {@code OAuth2AuthorizedClientService} depends on the table definition + * described in "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" + * and therefore MUST be defined in the database schema. + * + * @author Joe Grandja + * @since 5.3 + * @see OAuth2AuthorizedClientService + * @see OAuth2AuthorizedClient + * @see JdbcOperations + * @see RowMapper + */ +public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { + private static final String COLUMN_NAMES = + "client_registration_id, " + + "principal_name, " + + "access_token_type, " + + "access_token_value, " + + "access_token_issued_at, " + + "access_token_expires_at, " + + "access_token_scopes, " + + "refresh_token_value, " + + "refresh_token_issued_at"; + private static final String TABLE_NAME = "oauth2_authorized_client"; + private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?"; + private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + protected final JdbcOperations jdbcOperations; + protected RowMapper authorizedClientRowMapper; + protected Function> authorizedClientParametersMapper; + + /** + * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided parameters. + * + * @param jdbcOperations the JDBC operations + * @param clientRegistrationRepository the repository of client registrations + */ + public JdbcOAuth2AuthorizedClientService( + JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { + + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.jdbcOperations = jdbcOperations; + this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository); + this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper(); + } + + @Override + @SuppressWarnings("unchecked") + public T loadAuthorizedClient(String clientRegistrationId, String principalName) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, clientRegistrationId), + new SqlParameterValue(Types.VARCHAR, principalName) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + + List result = this.jdbcOperations.query( + LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); + + return !result.isEmpty() ? (T) result.get(0) : null; + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + + List parameters = this.authorizedClientParametersMapper.apply( + new OAuth2AuthorizedClientHolder(authorizedClient, principal)); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + + this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, String principalName) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, clientRegistrationId), + new SqlParameterValue(Types.VARCHAR, principalName) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + + this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); + } + + /** + * Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. + * The default is {@link OAuth2AuthorizedClientRowMapper}. + * + * @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient} + */ + public final void setAuthorizedClientRowMapper(RowMapper authorizedClientRowMapper) { + Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null"); + this.authorizedClientRowMapper = authorizedClientRowMapper; + } + + /** + * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}. + * The default is {@link OAuth2AuthorizedClientParametersMapper}. + * + * @param authorizedClientParametersMapper the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue} + */ + public final void setAuthorizedClientParametersMapper(Function> authorizedClientParametersMapper) { + Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null"); + this.authorizedClientParametersMapper = authorizedClientParametersMapper; + } + + /** + * The default {@link RowMapper} that maps the current row + * in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. + */ + public static class OAuth2AuthorizedClientRowMapper implements RowMapper { + protected final ClientRegistrationRepository clientRegistrationRepository; + + public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + } + + @Override + public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { + String clientRegistrationId = rs.getString("client_registration_id"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( + clientRegistrationId); + if (clientRegistration == null) { + throw new DataRetrievalFailureException("The ClientRegistration with id '" + + clientRegistrationId + "' exists in the data source, " + + "however, it was not found in the ClientRegistrationRepository."); + } + + OAuth2AccessToken.TokenType tokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( + rs.getString("access_token_type"))) { + tokenType = OAuth2AccessToken.TokenType.BEARER; + } + String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8); + Instant issuedAt = rs.getTimestamp("access_token_issued_at").toInstant(); + Instant expiresAt = rs.getTimestamp("access_token_expires_at").toInstant(); + Set scopes = Collections.emptySet(); + String accessTokenScopes = rs.getString("access_token_scopes"); + if (accessTokenScopes != null) { + scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); + } + OAuth2AccessToken accessToken = new OAuth2AccessToken( + tokenType, tokenValue, issuedAt, expiresAt, scopes); + + OAuth2RefreshToken refreshToken = null; + byte[] refreshTokenValue = rs.getBytes("refresh_token_value"); + if (refreshTokenValue != null) { + tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8); + issuedAt = null; + Timestamp refreshTokenIssuedAt = rs.getTimestamp("refresh_token_issued_at"); + if (refreshTokenIssuedAt != null) { + issuedAt = refreshTokenIssuedAt.toInstant(); + } + refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); + } + + String principalName = rs.getString("principal_name"); + + return new OAuth2AuthorizedClient( + clientRegistration, principalName, accessToken, refreshToken); + } + } + + /** + * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} + * to a {@code List} of {@link SqlParameterValue}. + */ + public static class OAuth2AuthorizedClientParametersMapper implements Function> { + + @Override + public List apply(OAuth2AuthorizedClientHolder authorizedClientHolder) { + OAuth2AuthorizedClient authorizedClient = authorizedClientHolder.getAuthorizedClient(); + Authentication principal = authorizedClientHolder.getPrincipal(); + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); + OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); + + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue( + Types.VARCHAR, clientRegistration.getRegistrationId())); + parameters.add(new SqlParameterValue( + Types.VARCHAR, principal.getName())); + parameters.add(new SqlParameterValue( + Types.VARCHAR, accessToken.getTokenType().getValue())); + parameters.add(new SqlParameterValue( + Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8))); + parameters.add(new SqlParameterValue( + Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt()))); + parameters.add(new SqlParameterValue( + Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt()))); + String accessTokenScopes = null; + if (!CollectionUtils.isEmpty(accessToken.getScopes())) { + accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ","); + } + parameters.add(new SqlParameterValue( + Types.VARCHAR, accessTokenScopes)); + byte[] refreshTokenValue = null; + Timestamp refreshTokenIssuedAt = null; + if (refreshToken != null) { + refreshTokenValue = refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8); + if (refreshToken.getIssuedAt() != null) { + refreshTokenIssuedAt = Timestamp.from(refreshToken.getIssuedAt()); + } + } + parameters.add(new SqlParameterValue( + Types.BLOB, refreshTokenValue)); + parameters.add(new SqlParameterValue( + Types.TIMESTAMP, refreshTokenIssuedAt)); + + return parameters; + } + } + + /** + * A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} (Resource Owner). + */ + public static final class OAuth2AuthorizedClientHolder { + private final OAuth2AuthorizedClient authorizedClient; + private final Authentication principal; + + /** + * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided parameters. + * + * @param authorizedClient the authorized client + * @param principal the End-User {@link Authentication} (Resource Owner) + */ + public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + this.authorizedClient = authorizedClient; + this.principal = principal; + } + + /** + * Returns the {@link OAuth2AuthorizedClient}. + * + * @return the {@link OAuth2AuthorizedClient} + */ + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the End-User {@link Authentication} (Resource Owner). + * + * @return the End-User {@link Authentication} (Resource Owner) + */ + public Authentication getPrincipal() { + return this.principal; + } + } +} diff --git a/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql new file mode 100644 index 0000000000..b4ceebf035 --- /dev/null +++ b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql @@ -0,0 +1,13 @@ +CREATE TABLE oauth2_authorized_client ( + client_registration_id varchar(100) NOT NULL, + principal_name varchar(200) NOT NULL, + access_token_type varchar(100) NOT NULL, + access_token_value blob NOT NULL, + access_token_issued_at timestamp NOT NULL, + access_token_expires_at timestamp NOT NULL, + access_token_scopes varchar(1000) DEFAULT NULL, + refresh_token_value blob DEFAULT NULL, + refresh_token_issued_at timestamp DEFAULT NULL, + created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (client_registration_id, principal_name) +); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java new file mode 100644 index 0000000000..40b2957e2a --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java @@ -0,0 +1,474 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.dao.DuplicateKeyException; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.nio.charset.StandardCharsets; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link JdbcOAuth2AuthorizedClientService}. + * + * @author Joe Grandja + */ +public class JdbcOAuth2AuthorizedClientServiceTests { + private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql"; + private static int principalId = 1000; + private ClientRegistration clientRegistration; + private ClientRegistrationRepository clientRegistrationRepository; + private EmbeddedDatabase db; + private JdbcOperations jdbcOperations; + private JdbcOAuth2AuthorizedClientService authorizedClientService; + + @Before + public void setUp() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.clientRegistration); + this.db = createDb(); + this.jdbcOperations = new JdbcTemplate(this.db); + this.authorizedClientService = new JdbcOAuth2AuthorizedClientService( + this.jdbcOperations, this.clientRegistrationRepository); + } + + @After + public void tearDown() { + this.db.shutdown(); + } + + @Test + public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jdbcOperations cannot be null"); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRowMapper cannot be null"); + } + + @Test + public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientParametersMapper cannot be null"); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + } + + @Test + public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principalName cannot be empty"); + } + + @Test + public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() { + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( + "registration-not-found", "principalName"); + assertThat(authorizedClient).isNull(); + } + + @Test + public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(expected, principal); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); + } + + @Test + public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(null); + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(expected, principal); + + assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())) + .isInstanceOf(DataRetrievalFailureException.class) + .hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + + "' exists in the data source, however, it was not found in the ClientRegistrationRepository."); + } + + @Test + public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + Authentication principal = createPrincipal(); + + assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(expected, principal); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); + + // Test save/load of NOT NULL attributes only + principal = createPrincipal(); + expected = createAuthorizedClient(principal, this.clientRegistration, true); + + this.authorizedClientService.saveAuthorizedClient(expected, principal); + + authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty(); + assertThat(authorizedClient.getRefreshToken()).isNull(); + } + + @Test + public void saveAuthorizedClientWhenSaveDuplicateThenThrowDuplicateKeyException() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + + assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal)) + .isInstanceOf(DuplicateKeyException.class); + } + + @Test + public void saveLoadAuthorizedClientWhenCustomStrategiesSetThenCalled() throws Exception { + JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = + spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper(this.clientRegistrationRepository)); + this.authorizedClientService.setAuthorizedClientRowMapper(authorizedClientRowMapper); + JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper = + spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper()); + this.authorizedClientService.setAuthorizedClientParametersMapper(authorizedClientParametersMapper); + + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + verify(authorizedClientRowMapper).mapRow(any(), anyInt()); + verify(authorizedClientParametersMapper).apply(any()); + } + + @Test + public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + } + + @Test + public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principalName cannot be empty"); + } + + @Test + public void removeAuthorizedClientWhenExistsThenRemoved() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + + authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + assertThat(authorizedClient).isNotNull(); + + this.authorizedClientService.removeAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + authorizedClient = this.authorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + assertThat(authorizedClient).isNull(); + } + + @Test + public void tableDefinitionWhenCustomThenAbleToOverride() { + CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService = + new CustomTableDefinitionJdbcOAuth2AuthorizedClientService( + new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")), + this.clientRegistrationRepository); + + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + customAuthorizedClientService.saveAuthorizedClient(authorizedClient, principal); + + authorizedClient = customAuthorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + assertThat(authorizedClient).isNotNull(); + + customAuthorizedClientService.removeAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + + authorizedClient = customAuthorizedClientService.loadAuthorizedClient( + this.clientRegistration.getRegistrationId(), principal.getName()); + assertThat(authorizedClient).isNull(); + } + + private static EmbeddedDatabase createDb() { + return createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE); + } + + private static EmbeddedDatabase createDb(String schema) { + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + } + + private static Authentication createPrincipal() { + return new TestingAuthenticationToken("principal-" + principalId++, "password"); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, ClientRegistration clientRegistration) { + return createAuthorizedClient(principal, clientRegistration, false); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, + ClientRegistration clientRegistration, boolean requiredAttributesOnly) { + OAuth2AccessToken accessToken; + if (!requiredAttributesOnly) { + accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + } else { + accessToken = TestOAuth2AccessTokens.noScopes(); + } + OAuth2RefreshToken refreshToken = null; + if (!requiredAttributesOnly) { + refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + return new OAuth2AuthorizedClient( + clientRegistration, principal.getName(), accessToken, refreshToken); + } + + private static class CustomTableDefinitionJdbcOAuth2AuthorizedClientService extends JdbcOAuth2AuthorizedClientService { + private static final String COLUMN_NAMES = + "clientRegistrationId, " + + "principalName, " + + "accessTokenType, " + + "accessTokenValue, " + + "accessTokenIssuedAt, " + + "accessTokenExpiresAt, " + + "accessTokenScopes, " + + "refreshTokenValue, " + + "refreshTokenIssuedAt"; + private static final String TABLE_NAME = "oauth2AuthorizedClient"; + private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?"; + private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + + private CustomTableDefinitionJdbcOAuth2AuthorizedClientService( + JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { + super(jdbcOperations, clientRegistrationRepository); + setAuthorizedClientRowMapper(new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository)); + } + + @Override + @SuppressWarnings("unchecked") + public T loadAuthorizedClient(String clientRegistrationId, String principalName) { + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, clientRegistrationId), + new SqlParameterValue(Types.VARCHAR, principalName) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query( + LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); + return !result.isEmpty() ? (T) result.get(0) : null; + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + List parameters = this.authorizedClientParametersMapper.apply( + new OAuth2AuthorizedClientHolder(authorizedClient, principal)); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, String principalName) { + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, clientRegistrationId), + new SqlParameterValue(Types.VARCHAR, principalName) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); + } + + private static class OAuth2AuthorizedClientRowMapper implements RowMapper { + private final ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + } + + @Override + public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { + String clientRegistrationId = rs.getString("clientRegistrationId"); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( + clientRegistrationId); + if (clientRegistration == null) { + throw new DataRetrievalFailureException("The ClientRegistration with id '" + + clientRegistrationId + "' exists in the data source, " + + "however, it was not found in the ClientRegistrationRepository."); + } + + OAuth2AccessToken.TokenType tokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( + rs.getString("accessTokenType"))) { + tokenType = OAuth2AccessToken.TokenType.BEARER; + } + String tokenValue = new String(rs.getBytes("accessTokenValue"), StandardCharsets.UTF_8); + Instant issuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant(); + Instant expiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant(); + Set scopes = Collections.emptySet(); + String accessTokenScopes = rs.getString("accessTokenScopes"); + if (accessTokenScopes != null) { + scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); + } + OAuth2AccessToken accessToken = new OAuth2AccessToken( + tokenType, tokenValue, issuedAt, expiresAt, scopes); + + OAuth2RefreshToken refreshToken = null; + byte[] refreshTokenValue = rs.getBytes("refreshTokenValue"); + if (refreshTokenValue != null) { + tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8); + issuedAt = null; + Timestamp refreshTokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt"); + if (refreshTokenIssuedAt != null) { + issuedAt = refreshTokenIssuedAt.toInstant(); + } + refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); + } + + String principalName = rs.getString("principalName"); + + return new OAuth2AuthorizedClient( + clientRegistration, principalName, accessToken, refreshToken); + } + } + } +} diff --git a/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql b/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql new file mode 100644 index 0000000000..9641169fdd --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql @@ -0,0 +1,13 @@ +CREATE TABLE oauth2AuthorizedClient ( + clientRegistrationId varchar(100) NOT NULL, + principalName varchar(200) NOT NULL, + accessTokenType varchar(100) NOT NULL, + accessTokenValue blob NOT NULL, + accessTokenIssuedAt timestamp NOT NULL, + accessTokenExpiresAt timestamp NOT NULL, + accessTokenScopes varchar(1000) DEFAULT NULL, + refreshTokenValue blob DEFAULT NULL, + refreshTokenIssuedAt timestamp DEFAULT NULL, + createdAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (clientRegistrationId, principalName) +);