Add JdbcAssertingPartyMetadataRepository

Closes gh-16012

Signed-off-by: chao.wang <chao.wang@zatech.com>
This commit is contained in:
chao.wang 2025-05-08 17:24:30 +08:00 committed by Josh Cummings
parent 9b724377ce
commit 16fd24c002
6 changed files with 420 additions and 0 deletions

View File

@ -106,6 +106,7 @@ dependencies {
provided 'jakarta.servlet:jakarta.servlet-api'
optional 'com.fasterxml.jackson.core:jackson-databind'
optional 'org.springframework:spring-jdbc'
testImplementation 'com.squareup.okhttp3:mockwebserver'
testImplementation "org.assertj:assertj-core"
@ -118,6 +119,7 @@ dependencies {
testImplementation "org.springframework:spring-test"
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
testRuntimeOnly 'org.hsqldb:hsqldb'
}
jar {

View File

@ -0,0 +1,190 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.registration;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.Deserializer;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
import org.springframework.util.Assert;
/**
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
*
* @author Cathy Wang
* @since 7.0
*/
public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository {
private final JdbcOperations jdbcOperations;
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
ResultSet::getBytes);
// @formatter:off
static final String COLUMN_NAMES = "entity_id, "
+ "singlesignon_url, "
+ "singlesignon_binding, "
+ "singlesignon_sign_request, "
+ "signing_algorithms, "
+ "verification_credentials, "
+ "encryption_credentials, "
+ "singlelogout_url, "
+ "singlelogout_response_url, "
+ "singlelogout_binding";
// @formatter:on
private static final String TABLE_NAME = "saml2_asserting_party_metadata";
private static final String ENTITY_ID_FILTER = "entity_id = ?";
// @formatter:off
private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
+ " FROM " + TABLE_NAME
+ " WHERE " + ENTITY_ID_FILTER;
private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
+ " FROM " + TABLE_NAME;
// @formatter:on
/**
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
* parameters.
* @param jdbcOperations the JDBC operations
*/
public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) {
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
this.jdbcOperations = jdbcOperations;
}
/**
* Sets the {@link RowMapper} used for mapping the current row in
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. The default is
* {@link AssertingPartyMetadataRowMapper}.
* @param assertingPartyMetadataRowMapper the {@link RowMapper} used for mapping the
* current row in {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}
*/
public void setAssertingPartyMetadataRowMapper(RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper) {
Assert.notNull(assertingPartyMetadataRowMapper, "assertingPartyMetadataRowMapper cannot be null");
this.assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper;
}
@Override
public AssertingPartyMetadata findByEntityId(String entityId) {
Assert.hasText(entityId, "entityId cannot be empty");
SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) };
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss,
this.assertingPartyMetadataRowMapper);
return !result.isEmpty() ? result.get(0) : null;
}
@Override
public Iterator<AssertingPartyMetadata> iterator() {
List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_ALL_SQL,
this.assertingPartyMetadataRowMapper);
return result.iterator();
}
/**
* The default {@link RowMapper} that maps the current row in
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
*/
private static final class AssertingPartyMetadataRowMapper implements RowMapper<AssertingPartyMetadata> {
private final Log logger = LogFactory.getLog(AssertingPartyMetadataRowMapper.class);
private final Deserializer<Object> deserializer = new DefaultDeserializer();
private final GetBytes getBytes;
AssertingPartyMetadataRowMapper(GetBytes getBytes) {
this.getBytes = getBytes;
}
@Override
public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException {
String entityId = rs.getString("entity_id");
String singleSignOnUrl = rs.getString("singlesignon_url");
Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString("singlesignon_binding"));
boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request");
String singleLogoutUrl = rs.getString("singlelogout_url");
String singleLogoutResponseUrl = rs.getString("singlelogout_response_url");
Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding"));
byte[] signingAlgorithmsBytes = this.getBytes.getBytes(rs, "signing_algorithms");
byte[] verificationCredentialsBytes = this.getBytes.getBytes(rs, "verification_credentials");
byte[] encryptionCredentialsBytes = this.getBytes.getBytes(rs, "encryption_credentials");
AssertingPartyMetadata.Builder<?> builder = new AssertingPartyDetails.Builder();
try {
if (signingAlgorithmsBytes != null) {
List<String> signingAlgorithms = (List<String>) this.deserializer
.deserializeFromByteArray(signingAlgorithmsBytes);
builder.signingAlgorithms((algorithms) -> algorithms.addAll(signingAlgorithms));
}
if (verificationCredentialsBytes != null) {
Collection<Saml2X509Credential> verificationCredentials = (Collection<Saml2X509Credential>) this.deserializer
.deserializeFromByteArray(verificationCredentialsBytes);
builder.verificationX509Credentials((credentials) -> credentials.addAll(verificationCredentials));
}
if (encryptionCredentialsBytes != null) {
Collection<Saml2X509Credential> encryptionCredentials = (Collection<Saml2X509Credential>) this.deserializer
.deserializeFromByteArray(encryptionCredentialsBytes);
builder.encryptionX509Credentials((credentials) -> credentials.addAll(encryptionCredentials));
}
}
catch (Exception ex) {
this.logger.debug(LogMessage.format("Parsing serialized credentials for entity %s failed", entityId),
ex);
return null;
}
builder.entityId(entityId)
.wantAuthnRequestsSigned(singleSignOnSignRequest)
.singleSignOnServiceLocation(singleSignOnUrl)
.singleSignOnServiceBinding(singleSignOnBinding)
.singleLogoutServiceLocation(singleLogoutUrl)
.singleLogoutServiceBinding(singleLogoutBinding)
.singleLogoutServiceResponseLocation(singleLogoutResponseUrl);
return builder.build();
}
}
private interface GetBytes {
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
}
}

View File

@ -0,0 +1,14 @@
CREATE TABLE saml2_asserting_party_metadata
(
entity_id VARCHAR(1000) NOT NULL,
singlesignon_url VARCHAR(1000) NOT NULL,
singlesignon_binding VARCHAR(100),
singlesignon_sign_request boolean,
signing_algorithms BYTEA,
verification_credentials BYTEA NOT NULL,
encryption_credentials BYTEA,
singlelogout_url VARCHAR(1000),
singlelogout_response_url VARCHAR(1000),
singlelogout_binding VARCHAR(100),
PRIMARY KEY (entity_id)
);

View File

@ -0,0 +1,14 @@
CREATE TABLE saml2_asserting_party_metadata
(
entity_id VARCHAR(1000) NOT NULL,
singlesignon_url VARCHAR(1000) NOT NULL,
singlesignon_binding VARCHAR(100),
singlesignon_sign_request boolean,
signing_algorithms blob,
verification_credentials blob NOT NULL,
encryption_credentials blob,
singlelogout_url VARCHAR(1000),
singlelogout_response_url VARCHAR(1000),
singlelogout_binding VARCHAR(100),
PRIMARY KEY (entity_id)
);

View File

@ -0,0 +1,177 @@
package org.springframework.security.saml2.provider.service.registration;
import java.io.IOException;
import java.io.InputStream;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Serializer;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/**
* Tests for {@link JdbcAssertingPartyMetadataRepository}
*/
class JdbcAssertingPartyMetadataRepositoryTests {
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
private X509Certificate certificate;
private EmbeddedDatabase db;
private JdbcAssertingPartyMetadataRepository repository;
private JdbcOperations jdbcOperations;
private final Serializer<Object> serializer = new DefaultSerializer();
@BeforeEach
public void setUp() throws Exception {
this.db = createDb();
this.jdbcOperations = new JdbcTemplate(this.db);
this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
this.certificate = loadCertificate("rsa.crt");
}
@AfterEach
public void tearDown() {
this.db.shutdown();
}
@Test
void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new JdbcAssertingPartyMetadataRepository(null))
.withMessage("jdbcOperations cannot be null");
// @formatter:on
}
@Test
void findByEntityIdWhenEntityIdIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.repository.findByEntityId(null))
.withMessage("entityId cannot be empty");
// @formatter:on
}
@Test
void findByEntityId() throws IOException {
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
assertThat(found).isNotNull();
assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
assertThat(found.getVerificationX509Credentials()).hasSize(1);
assertThat(found.getEncryptionX509Credentials()).hasSize(1);
}
@Test
void findByEntityIdWhenNotExists() {
AssertingPartyMetadata found = this.repository.findByEntityId("non-existent-entity-id");
assertThat(found).isNull();
}
@Test
void iterator() throws IOException {
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
AssertingPartyMetadata first = iterator.next();
assertThat(first).isNotNull();
AssertingPartyMetadata second = iterator.next();
assertThat(second).isNotNull();
assertThat(iterator.hasNext()).isFalse();
}
private static EmbeddedDatabase createDb() {
return createDb(SCHEMA_SQL_RESOURCE);
}
private static EmbeddedDatabase createDb(String schema) {
// @formatter:off
return new EmbeddedDatabaseBuilder()
.generateUniqueName(true)
.setType(EmbeddedDatabaseType.HSQL)
.setScriptEncoding("UTF-8")
.addScript(schema)
.build();
// @formatter:on
}
private X509Certificate loadCertificate(String path) {
try (InputStream is = new ClassPathResource(path).getInputStream()) {
CertificateFactory factory = CertificateFactory.getInstance("X.509");
return (X509Certificate) factory.generateCertificate(is);
}
catch (Exception ex) {
throw new RuntimeException("Error loading certificate from " + path, ex);
}
}
private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
}
}

View File

@ -0,0 +1,23 @@
-----BEGIN CERTIFICATE-----
MIID1zCCAr+gAwIBAgIUCzQeKBMTO0iHVW3iKmZC41haqCowDQYJKoZIhvcNAQEL
BQAwezELMAkGA1UEBhMCWFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwI
Q2l0eU5hbWUxFDASBgNVBAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55
U2VjdGlvbk5hbWUxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzA5MjAwODI5MDNa
Fw0zMzA5MTcwODI5MDNaMHsxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5h
bWUxETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkG
A1UECwwSQ29tcGFueVNlY3Rpb25OYW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwggEi
MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUfi4aaCotJZX6OSDjv6fxCCfc
ihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTkFPbeDxAl75Oa
PGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQXnIXTBL2o76V
nuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB8PFFhCZXW2/9
TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0sburHUlOnPGUh
Qj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQjHcvET9jvAgMB
AAGjUzBRMB0GA1UdDgQWBBQjDr/1E/01pfLPD8uWF7gbaYL0TTAfBgNVHSMEGDAW
gBQjDr/1E/01pfLPD8uWF7gbaYL0TTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
DQEBCwUAA4IBAQAGjUuec0+0XNMCRDKZslbImdCAVsKsEWk6NpnUViDFAxL+KQuC
NW131UeHb9SCzMqRwrY4QI3nAwJQCmilL/hFM3ss4acn3WHu1yci/iKPUKeL1ec5
kCFUmqX1NpTiVaytZ/9TKEr69SMVqNfQiuW5U1bIIYTqK8xo46WpM6YNNHO3eJK6
NH0MW79Wx5ryi4i4C6afqYbVbx7tqcmy8CFeNxgZ0bFQ87SiwYXIj77b6sVYbu32
doykBQgSHLcagWASPQ73m73CWUgo+7+EqSKIQqORbgmTLPmOUh99gFIx7jmjTyHm
NBszx1ZVWuIv3mWmp626Kncyc+LLM9tvgymx
-----END CERTIFICATE-----