diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle
index d43a0849e1..6e17822390 100644
--- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle
+++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle
@@ -12,6 +12,7 @@ dependencies {
optional 'org.springframework:spring-webflux'
optional 'com.fasterxml.jackson.core:jackson-databind'
optional 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
+ optional 'org.springframework:spring-jdbc'
testCompile project(path: ':spring-security-oauth2-core', configuration: 'tests')
testCompile project(path: ':spring-security-oauth2-jose', configuration: 'tests')
@@ -22,5 +23,7 @@ dependencies {
testCompile 'io.projectreactor.tools:blockhound'
testCompile 'org.skyscreamer:jsonassert'
+ testRuntime 'org.hsqldb:hsqldb'
+
provided 'javax.servlet:javax.servlet-api'
}
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java
new file mode 100644
index 0000000000..a4ac907123
--- /dev/null
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java
@@ -0,0 +1,312 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * A JDBC implementation of an {@link OAuth2AuthorizedClientService}
+ * that uses a {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence.
+ *
+ *
+ * NOTE: This {@code OAuth2AuthorizedClientService} depends on the table definition
+ * described in "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql"
+ * and therefore MUST be defined in the database schema.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClient
+ * @see JdbcOperations
+ * @see RowMapper
+ */
+public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
+ private static final String COLUMN_NAMES =
+ "client_registration_id, " +
+ "principal_name, " +
+ "access_token_type, " +
+ "access_token_value, " +
+ "access_token_issued_at, " +
+ "access_token_expires_at, " +
+ "access_token_scopes, " +
+ "refresh_token_value, " +
+ "refresh_token_issued_at";
+ private static final String TABLE_NAME = "oauth2_authorized_client";
+ private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?";
+ private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES +
+ " FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+ private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME +
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+ private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME +
+ " WHERE " + PK_FILTER;
+ protected final JdbcOperations jdbcOperations;
+ protected RowMapper authorizedClientRowMapper;
+ protected Function> authorizedClientParametersMapper;
+
+ /**
+ * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided parameters.
+ *
+ * @param jdbcOperations the JDBC operations
+ * @param clientRegistrationRepository the repository of client registrations
+ */
+ public JdbcOAuth2AuthorizedClientService(
+ JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) {
+
+ Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ this.jdbcOperations = jdbcOperations;
+ this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository);
+ this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper();
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public T loadAuthorizedClient(String clientRegistrationId, String principalName) {
+ Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+ Assert.hasText(principalName, "principalName cannot be empty");
+
+ SqlParameterValue[] parameters = new SqlParameterValue[] {
+ new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+ new SqlParameterValue(Types.VARCHAR, principalName)
+ };
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+
+ List result = this.jdbcOperations.query(
+ LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper);
+
+ return !result.isEmpty() ? (T) result.get(0) : null;
+ }
+
+ @Override
+ public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+ Assert.notNull(principal, "principal cannot be null");
+
+ List parameters = this.authorizedClientParametersMapper.apply(
+ new OAuth2AuthorizedClientHolder(authorizedClient, principal));
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+
+ this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
+ }
+
+ @Override
+ public void removeAuthorizedClient(String clientRegistrationId, String principalName) {
+ Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+ Assert.hasText(principalName, "principalName cannot be empty");
+
+ SqlParameterValue[] parameters = new SqlParameterValue[] {
+ new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+ new SqlParameterValue(Types.VARCHAR, principalName)
+ };
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+
+ this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss);
+ }
+
+ /**
+ * Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}.
+ * The default is {@link OAuth2AuthorizedClientRowMapper}.
+ *
+ * @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}
+ */
+ public final void setAuthorizedClientRowMapper(RowMapper authorizedClientRowMapper) {
+ Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null");
+ this.authorizedClientRowMapper = authorizedClientRowMapper;
+ }
+
+ /**
+ * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}.
+ * The default is {@link OAuth2AuthorizedClientParametersMapper}.
+ *
+ * @param authorizedClientParametersMapper the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}
+ */
+ public final void setAuthorizedClientParametersMapper(Function> authorizedClientParametersMapper) {
+ Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null");
+ this.authorizedClientParametersMapper = authorizedClientParametersMapper;
+ }
+
+ /**
+ * The default {@link RowMapper} that maps the current row
+ * in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}.
+ */
+ public static class OAuth2AuthorizedClientRowMapper implements RowMapper {
+ protected final ClientRegistrationRepository clientRegistrationRepository;
+
+ public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) {
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ this.clientRegistrationRepository = clientRegistrationRepository;
+ }
+
+ @Override
+ public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException {
+ String clientRegistrationId = rs.getString("client_registration_id");
+ ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(
+ clientRegistrationId);
+ if (clientRegistration == null) {
+ throw new DataRetrievalFailureException("The ClientRegistration with id '" +
+ clientRegistrationId + "' exists in the data source, " +
+ "however, it was not found in the ClientRegistrationRepository.");
+ }
+
+ OAuth2AccessToken.TokenType tokenType = null;
+ if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+ rs.getString("access_token_type"))) {
+ tokenType = OAuth2AccessToken.TokenType.BEARER;
+ }
+ String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8);
+ Instant issuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
+ Instant expiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
+ Set scopes = Collections.emptySet();
+ String accessTokenScopes = rs.getString("access_token_scopes");
+ if (accessTokenScopes != null) {
+ scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+ }
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(
+ tokenType, tokenValue, issuedAt, expiresAt, scopes);
+
+ OAuth2RefreshToken refreshToken = null;
+ byte[] refreshTokenValue = rs.getBytes("refresh_token_value");
+ if (refreshTokenValue != null) {
+ tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
+ issuedAt = null;
+ Timestamp refreshTokenIssuedAt = rs.getTimestamp("refresh_token_issued_at");
+ if (refreshTokenIssuedAt != null) {
+ issuedAt = refreshTokenIssuedAt.toInstant();
+ }
+ refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt);
+ }
+
+ String principalName = rs.getString("principal_name");
+
+ return new OAuth2AuthorizedClient(
+ clientRegistration, principalName, accessToken, refreshToken);
+ }
+ }
+
+ /**
+ * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder}
+ * to a {@code List} of {@link SqlParameterValue}.
+ */
+ public static class OAuth2AuthorizedClientParametersMapper implements Function> {
+
+ @Override
+ public List apply(OAuth2AuthorizedClientHolder authorizedClientHolder) {
+ OAuth2AuthorizedClient authorizedClient = authorizedClientHolder.getAuthorizedClient();
+ Authentication principal = authorizedClientHolder.getPrincipal();
+ ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
+ OAuth2AccessToken accessToken = authorizedClient.getAccessToken();
+ OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
+
+ List parameters = new ArrayList<>();
+ parameters.add(new SqlParameterValue(
+ Types.VARCHAR, clientRegistration.getRegistrationId()));
+ parameters.add(new SqlParameterValue(
+ Types.VARCHAR, principal.getName()));
+ parameters.add(new SqlParameterValue(
+ Types.VARCHAR, accessToken.getTokenType().getValue()));
+ parameters.add(new SqlParameterValue(
+ Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8)));
+ parameters.add(new SqlParameterValue(
+ Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt())));
+ parameters.add(new SqlParameterValue(
+ Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt())));
+ String accessTokenScopes = null;
+ if (!CollectionUtils.isEmpty(accessToken.getScopes())) {
+ accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ",");
+ }
+ parameters.add(new SqlParameterValue(
+ Types.VARCHAR, accessTokenScopes));
+ byte[] refreshTokenValue = null;
+ Timestamp refreshTokenIssuedAt = null;
+ if (refreshToken != null) {
+ refreshTokenValue = refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8);
+ if (refreshToken.getIssuedAt() != null) {
+ refreshTokenIssuedAt = Timestamp.from(refreshToken.getIssuedAt());
+ }
+ }
+ parameters.add(new SqlParameterValue(
+ Types.BLOB, refreshTokenValue));
+ parameters.add(new SqlParameterValue(
+ Types.TIMESTAMP, refreshTokenIssuedAt));
+
+ return parameters;
+ }
+ }
+
+ /**
+ * A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} (Resource Owner).
+ */
+ public static final class OAuth2AuthorizedClientHolder {
+ private final OAuth2AuthorizedClient authorizedClient;
+ private final Authentication principal;
+
+ /**
+ * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided parameters.
+ *
+ * @param authorizedClient the authorized client
+ * @param principal the End-User {@link Authentication} (Resource Owner)
+ */
+ public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+ Assert.notNull(principal, "principal cannot be null");
+ this.authorizedClient = authorizedClient;
+ this.principal = principal;
+ }
+
+ /**
+ * Returns the {@link OAuth2AuthorizedClient}.
+ *
+ * @return the {@link OAuth2AuthorizedClient}
+ */
+ public OAuth2AuthorizedClient getAuthorizedClient() {
+ return this.authorizedClient;
+ }
+
+ /**
+ * Returns the End-User {@link Authentication} (Resource Owner).
+ *
+ * @return the End-User {@link Authentication} (Resource Owner)
+ */
+ public Authentication getPrincipal() {
+ return this.principal;
+ }
+ }
+}
diff --git a/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql
new file mode 100644
index 0000000000..b4ceebf035
--- /dev/null
+++ b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql
@@ -0,0 +1,13 @@
+CREATE TABLE oauth2_authorized_client (
+ client_registration_id varchar(100) NOT NULL,
+ principal_name varchar(200) NOT NULL,
+ access_token_type varchar(100) NOT NULL,
+ access_token_value blob NOT NULL,
+ access_token_issued_at timestamp NOT NULL,
+ access_token_expires_at timestamp NOT NULL,
+ access_token_scopes varchar(1000) DEFAULT NULL,
+ refresh_token_value blob DEFAULT NULL,
+ refresh_token_issued_at timestamp DEFAULT NULL,
+ created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
+ PRIMARY KEY (client_registration_id, principal_name)
+);
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java
new file mode 100644
index 0000000000..40b2957e2a
--- /dev/null
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java
@@ -0,0 +1,474 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.dao.DuplicateKeyException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link JdbcOAuth2AuthorizedClientService}.
+ *
+ * @author Joe Grandja
+ */
+public class JdbcOAuth2AuthorizedClientServiceTests {
+ private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql";
+ private static int principalId = 1000;
+ private ClientRegistration clientRegistration;
+ private ClientRegistrationRepository clientRegistrationRepository;
+ private EmbeddedDatabase db;
+ private JdbcOperations jdbcOperations;
+ private JdbcOAuth2AuthorizedClientService authorizedClientService;
+
+ @Before
+ public void setUp() {
+ this.clientRegistration = TestClientRegistrations.clientRegistration().build();
+ this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+ when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.clientRegistration);
+ this.db = createDb();
+ this.jdbcOperations = new JdbcTemplate(this.db);
+ this.authorizedClientService = new JdbcOAuth2AuthorizedClientService(
+ this.jdbcOperations, this.clientRegistrationRepository);
+ }
+
+ @After
+ public void tearDown() {
+ this.db.shutdown();
+ }
+
+ @Test
+ public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("jdbcOperations cannot be null");
+ }
+
+ @Test
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientRegistrationRepository cannot be null");
+ }
+
+ @Test
+ public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizedClientRowMapper cannot be null");
+ }
+
+ @Test
+ public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizedClientParametersMapper cannot be null");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientRegistrationId cannot be empty");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("principalName cannot be empty");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() {
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ "registration-not-found", "principalName");
+ assertThat(authorizedClient).isNull();
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+ assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue());
+ assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() {
+ when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(null);
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+ assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()))
+ .isInstanceOf(DataRetrievalFailureException.class)
+ .hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() +
+ "' exists in the data source, however, it was not found in the ClientRegistrationRepository.");
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
+ Authentication principal = createPrincipal();
+
+ assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizedClient cannot be null");
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("principal cannot be null");
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+ assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue());
+ assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+
+ // Test save/load of NOT NULL attributes only
+ principal = createPrincipal();
+ expected = createAuthorizedClient(principal, this.clientRegistration, true);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+ authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty();
+ assertThat(authorizedClient.getRefreshToken()).isNull();
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenSaveDuplicateThenThrowDuplicateKeyException() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+ assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal))
+ .isInstanceOf(DuplicateKeyException.class);
+ }
+
+ @Test
+ public void saveLoadAuthorizedClientWhenCustomStrategiesSetThenCalled() throws Exception {
+ JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper =
+ spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper(this.clientRegistrationRepository));
+ this.authorizedClientService.setAuthorizedClientRowMapper(authorizedClientRowMapper);
+ JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper =
+ spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper());
+ this.authorizedClientService.setAuthorizedClientParametersMapper(authorizedClientParametersMapper);
+
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+ this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ verify(authorizedClientRowMapper).mapRow(any(), anyInt());
+ verify(authorizedClientParametersMapper).apply(any());
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientRegistrationId cannot be empty");
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("principalName cannot be empty");
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenExistsThenRemoved() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+ authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+ assertThat(authorizedClient).isNotNull();
+
+ this.authorizedClientService.removeAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+ assertThat(authorizedClient).isNull();
+ }
+
+ @Test
+ public void tableDefinitionWhenCustomThenAbleToOverride() {
+ CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService =
+ new CustomTableDefinitionJdbcOAuth2AuthorizedClientService(
+ new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")),
+ this.clientRegistrationRepository);
+
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ customAuthorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+ authorizedClient = customAuthorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+ assertThat(authorizedClient).isNotNull();
+
+ customAuthorizedClientService.removeAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+
+ authorizedClient = customAuthorizedClientService.loadAuthorizedClient(
+ this.clientRegistration.getRegistrationId(), principal.getName());
+ assertThat(authorizedClient).isNull();
+ }
+
+ private static EmbeddedDatabase createDb() {
+ return createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE);
+ }
+
+ private static EmbeddedDatabase createDb(String schema) {
+ return new EmbeddedDatabaseBuilder()
+ .generateUniqueName(true)
+ .setType(EmbeddedDatabaseType.HSQL)
+ .setScriptEncoding("UTF-8")
+ .addScript(schema)
+ .build();
+ }
+
+ private static Authentication createPrincipal() {
+ return new TestingAuthenticationToken("principal-" + principalId++, "password");
+ }
+
+ private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, ClientRegistration clientRegistration) {
+ return createAuthorizedClient(principal, clientRegistration, false);
+ }
+
+ private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal,
+ ClientRegistration clientRegistration, boolean requiredAttributesOnly) {
+ OAuth2AccessToken accessToken;
+ if (!requiredAttributesOnly) {
+ accessToken = TestOAuth2AccessTokens.scopes("read", "write");
+ } else {
+ accessToken = TestOAuth2AccessTokens.noScopes();
+ }
+ OAuth2RefreshToken refreshToken = null;
+ if (!requiredAttributesOnly) {
+ refreshToken = TestOAuth2RefreshTokens.refreshToken();
+ }
+ return new OAuth2AuthorizedClient(
+ clientRegistration, principal.getName(), accessToken, refreshToken);
+ }
+
+ private static class CustomTableDefinitionJdbcOAuth2AuthorizedClientService extends JdbcOAuth2AuthorizedClientService {
+ private static final String COLUMN_NAMES =
+ "clientRegistrationId, " +
+ "principalName, " +
+ "accessTokenType, " +
+ "accessTokenValue, " +
+ "accessTokenIssuedAt, " +
+ "accessTokenExpiresAt, " +
+ "accessTokenScopes, " +
+ "refreshTokenValue, " +
+ "refreshTokenIssuedAt";
+ private static final String TABLE_NAME = "oauth2AuthorizedClient";
+ private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?";
+ private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES +
+ " FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+ private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME +
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+ private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME +
+ " WHERE " + PK_FILTER;
+
+ private CustomTableDefinitionJdbcOAuth2AuthorizedClientService(
+ JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) {
+ super(jdbcOperations, clientRegistrationRepository);
+ setAuthorizedClientRowMapper(new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository));
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public T loadAuthorizedClient(String clientRegistrationId, String principalName) {
+ SqlParameterValue[] parameters = new SqlParameterValue[] {
+ new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+ new SqlParameterValue(Types.VARCHAR, principalName)
+ };
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+ List result = this.jdbcOperations.query(
+ LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper);
+ return !result.isEmpty() ? (T) result.get(0) : null;
+ }
+
+ @Override
+ public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ List parameters = this.authorizedClientParametersMapper.apply(
+ new OAuth2AuthorizedClientHolder(authorizedClient, principal));
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+ this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
+ }
+
+ @Override
+ public void removeAuthorizedClient(String clientRegistrationId, String principalName) {
+ SqlParameterValue[] parameters = new SqlParameterValue[] {
+ new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+ new SqlParameterValue(Types.VARCHAR, principalName)
+ };
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+ this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss);
+ }
+
+ private static class OAuth2AuthorizedClientRowMapper implements RowMapper {
+ private final ClientRegistrationRepository clientRegistrationRepository;
+
+ private OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) {
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ this.clientRegistrationRepository = clientRegistrationRepository;
+ }
+
+ @Override
+ public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException {
+ String clientRegistrationId = rs.getString("clientRegistrationId");
+ ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(
+ clientRegistrationId);
+ if (clientRegistration == null) {
+ throw new DataRetrievalFailureException("The ClientRegistration with id '" +
+ clientRegistrationId + "' exists in the data source, " +
+ "however, it was not found in the ClientRegistrationRepository.");
+ }
+
+ OAuth2AccessToken.TokenType tokenType = null;
+ if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+ rs.getString("accessTokenType"))) {
+ tokenType = OAuth2AccessToken.TokenType.BEARER;
+ }
+ String tokenValue = new String(rs.getBytes("accessTokenValue"), StandardCharsets.UTF_8);
+ Instant issuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant();
+ Instant expiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant();
+ Set scopes = Collections.emptySet();
+ String accessTokenScopes = rs.getString("accessTokenScopes");
+ if (accessTokenScopes != null) {
+ scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+ }
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(
+ tokenType, tokenValue, issuedAt, expiresAt, scopes);
+
+ OAuth2RefreshToken refreshToken = null;
+ byte[] refreshTokenValue = rs.getBytes("refreshTokenValue");
+ if (refreshTokenValue != null) {
+ tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
+ issuedAt = null;
+ Timestamp refreshTokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt");
+ if (refreshTokenIssuedAt != null) {
+ issuedAt = refreshTokenIssuedAt.toInstant();
+ }
+ refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt);
+ }
+
+ String principalName = rs.getString("principalName");
+
+ return new OAuth2AuthorizedClient(
+ clientRegistration, principalName, accessToken, refreshToken);
+ }
+ }
+ }
+}
diff --git a/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql b/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql
new file mode 100644
index 0000000000..9641169fdd
--- /dev/null
+++ b/oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql
@@ -0,0 +1,13 @@
+CREATE TABLE oauth2AuthorizedClient (
+ clientRegistrationId varchar(100) NOT NULL,
+ principalName varchar(200) NOT NULL,
+ accessTokenType varchar(100) NOT NULL,
+ accessTokenValue blob NOT NULL,
+ accessTokenIssuedAt timestamp NOT NULL,
+ accessTokenExpiresAt timestamp NOT NULL,
+ accessTokenScopes varchar(1000) DEFAULT NULL,
+ refreshTokenValue blob DEFAULT NULL,
+ refreshTokenIssuedAt timestamp DEFAULT NULL,
+ createdAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
+ PRIMARY KEY (clientRegistrationId, principalName)
+);