Add JdbcAssertingPartyMetadataRepository#save

Issue gh-16012

Co-Authored-By: chao.wang <chao.wang@zatech.com>
This commit is contained in:
Josh Cummings 2025-06-05 14:44:19 -06:00
parent e2e42a5580
commit 2bd05128ec
2 changed files with 117 additions and 81 deletions

View File

@ -19,16 +19,20 @@ package org.springframework.security.saml2.provider.service.registration;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Deserializer;
import org.springframework.core.serializer.Serializer;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
@ -37,6 +41,7 @@ import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
import org.springframework.util.Assert;
import org.springframework.util.function.ThrowingFunction;
/**
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
@ -51,6 +56,8 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
ResultSet::getBytes);
private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();
// @formatter:off
static final String COLUMN_NAMES = "entity_id, "
+ "singlesignon_url, "
@ -77,6 +84,25 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
+ " FROM " + TABLE_NAME;
// @formatter:on
// @formatter:off
private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
// @formatter:on
// @formatter:off
private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
+ " SET singlesignon_url = ?, "
+ "singlesignon_binding = ?, "
+ "singlesignon_sign_request = ?, "
+ "signing_algorithms = ?, "
+ "verification_credentials = ?, "
+ "encryption_credentials = ?, "
+ "singlelogout_url = ?, "
+ "singlelogout_response_url = ?, "
+ "singlelogout_binding = ?"
+ " WHERE " + ENTITY_ID_FILTER;
// @formatter:on
/**
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
* parameters.
@ -116,6 +142,30 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
return result.iterator();
}
/**
* Persist this {@link AssertingPartyMetadata}
* @param metadata the metadata to persist
*/
public void save(AssertingPartyMetadata metadata) {
Assert.notNull(metadata, "metadata cannot be null");
int rows = updateCredentialRecord(metadata);
if (rows == 0) {
insertCredentialRecord(metadata);
}
}
private void insertCredentialRecord(AssertingPartyMetadata metadata) {
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray());
}
private int updateCredentialRecord(AssertingPartyMetadata metadata) {
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
SqlParameterValue credentialId = parameters.remove(0);
parameters.add(credentialId);
return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray());
}
/**
* The default {@link RowMapper} that maps the current row in
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
@ -181,6 +231,34 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
}
private static class AssertingPartyMetadataParametersMapper
implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {
private final Serializer<Object> serializer = new DefaultSerializer();
@Override
public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
List<SqlParameterValue> parameters = new ArrayList<>();
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
ThrowingFunction<List<String>, byte[]> algorithms = this.serializer::serializeToByteArray;
parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms())));
ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
parameters
.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials())));
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn()));
return parameters;
}
}
private interface GetBytes {
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;

View File

@ -16,27 +16,17 @@
package org.springframework.security.saml2.provider.service.registration;
import java.io.IOException;
import java.io.InputStream;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Serializer;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.security.saml2.core.Saml2X509Credential;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
private X509Certificate certificate;
private EmbeddedDatabase db;
private JdbcAssertingPartyMetadataRepository repository;
private JdbcOperations jdbcOperations;
private final Serializer<Object> serializer = new DefaultSerializer();
private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full()
.build()
.getAssertingPartyMetadata();
@BeforeEach
void setUp() {
this.db = createDb();
this.jdbcOperations = new JdbcTemplate(this.db);
this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
this.certificate = loadCertificate("rsa.crt");
}
@AfterEach
@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests {
}
@Test
void findByEntityId() throws IOException {
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
void findByEntityId() {
this.repository.save(this.metadata);
AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId());
assertThat(found).isNotNull();
assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
assertThat(found.getVerificationX509Credentials()).hasSize(1);
assertThat(found.getEncryptionX509Credentials()).hasSize(1);
assertAssertingPartyEquals(found, this.metadata);
}
@Test
@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests {
}
@Test
void iterator() throws IOException {
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
void iterator() {
AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata)
.assertingPartyMetadata((a) -> a.entityId("https://example.org/idp"))
.build()
.getAssertingPartyMetadata();
this.repository.save(this.metadata);
this.repository.save(second);
Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
AssertingPartyMetadata first = iterator.next();
assertThat(first).isNotNull();
AssertingPartyMetadata second = iterator.next();
assertThat(second).isNotNull();
assertAssertingPartyEquals(iterator.next(), this.metadata);
assertAssertingPartyEquals(iterator.next(), second);
assertThat(iterator.hasNext()).isFalse();
}
@Test
void saveWhenExistingThenUpdates() {
this.repository.save(this.metadata);
boolean existing = this.metadata.getWantAuthnRequestsSigned();
this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build());
boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned();
assertThat(existing).isNotEqualTo(updated);
}
private static EmbeddedDatabase createDb() {
return createDb(SCHEMA_SQL_RESOURCE);
}
@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests {
// @formatter:on
}
private X509Certificate loadCertificate(String path) {
try (InputStream is = new ClassPathResource(path).getInputStream()) {
CertificateFactory factory = CertificateFactory.getInstance("X.509");
return (X509Certificate) factory.generateCertificate(is);
}
catch (Exception ex) {
throw new RuntimeException("Error loading certificate from " + path, ex);
}
}
private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) {
assertThat(found).isNotNull();
assertThat(found.getEntityId()).isEqualTo(expected.getEntityId());
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation());
assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding());
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned());
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation());
assertThat(found.getSingleLogoutServiceResponseLocation())
.isEqualTo(expected.getSingleLogoutServiceResponseLocation());
assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding());
assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms());
assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials());
assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials());
}
}