From 2bd05128ec614243f563a70fe0b6c09e726a29be Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Thu, 5 Jun 2025 14:44:19 -0600 Subject: [PATCH] Add JdbcAssertingPartyMetadataRepository#save Issue gh-16012 Co-Authored-By: chao.wang --- .../JdbcAssertingPartyMetadataRepository.java | 78 ++++++++++++ ...AssertingPartyMetadataRepositoryTests.java | 120 ++++++------------ 2 files changed, 117 insertions(+), 81 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java index 620e6bdf2e..88f4a59a71 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java @@ -19,16 +19,20 @@ package org.springframework.security.saml2.provider.service.registration; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; +import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; +import java.util.function.Function; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.log.LogMessage; import org.springframework.core.serializer.DefaultDeserializer; +import org.springframework.core.serializer.DefaultSerializer; import org.springframework.core.serializer.Deserializer; +import org.springframework.core.serializer.Serializer; import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.PreparedStatementSetter; @@ -37,6 +41,7 @@ import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails; import org.springframework.util.Assert; +import org.springframework.util.function.ThrowingFunction; /** * A JDBC implementation of {@link AssertingPartyMetadataRepository}. @@ -51,6 +56,8 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart private RowMapper assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper( ResultSet::getBytes); + private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper(); + // @formatter:off static final String COLUMN_NAMES = "entity_id, " + "singlesignon_url, " @@ -77,6 +84,25 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart + " FROM " + TABLE_NAME; // @formatter:on + // @formatter:off + private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + // @formatter:on + + // @formatter:off + private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME + + " SET singlesignon_url = ?, " + + "singlesignon_binding = ?, " + + "singlesignon_sign_request = ?, " + + "signing_algorithms = ?, " + + "verification_credentials = ?, " + + "encryption_credentials = ?, " + + "singlelogout_url = ?, " + + "singlelogout_response_url = ?, " + + "singlelogout_binding = ?" + + " WHERE " + ENTITY_ID_FILTER; + // @formatter:on + /** * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided * parameters. @@ -116,6 +142,30 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart return result.iterator(); } + /** + * Persist this {@link AssertingPartyMetadata} + * @param metadata the metadata to persist + */ + public void save(AssertingPartyMetadata metadata) { + Assert.notNull(metadata, "metadata cannot be null"); + int rows = updateCredentialRecord(metadata); + if (rows == 0) { + insertCredentialRecord(metadata); + } + } + + private void insertCredentialRecord(AssertingPartyMetadata metadata) { + List parameters = this.assertingPartyMetadataParametersMapper.apply(metadata); + this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray()); + } + + private int updateCredentialRecord(AssertingPartyMetadata metadata) { + List parameters = this.assertingPartyMetadataParametersMapper.apply(metadata); + SqlParameterValue credentialId = parameters.remove(0); + parameters.add(credentialId); + return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray()); + } + /** * The default {@link RowMapper} that maps the current row in * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. @@ -181,6 +231,34 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart } + private static class AssertingPartyMetadataParametersMapper + implements Function> { + + private final Serializer serializer = new DefaultSerializer(); + + @Override + public List apply(AssertingPartyMetadata record) { + List parameters = new ArrayList<>(); + + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn())); + parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned())); + ThrowingFunction, byte[]> algorithms = this.serializer::serializeToByteArray; + parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms()))); + ThrowingFunction, byte[]> credentials = this.serializer::serializeToByteArray; + parameters + .add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials()))); + parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials()))); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn())); + + return parameters; + } + + } + private interface GetBytes { byte[] getBytes(ResultSet rs, String columnName) throws SQLException; diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java index e8a76ee335..785e4a12f3 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java @@ -16,27 +16,17 @@ package org.springframework.security.saml2.provider.service.registration; -import java.io.IOException; -import java.io.InputStream; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; -import java.util.Collection; import java.util.Iterator; -import java.util.List; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.core.io.ClassPathResource; -import org.springframework.core.serializer.DefaultSerializer; -import org.springframework.core.serializer.Serializer; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; -import org.springframework.security.saml2.core.Saml2X509Credential; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests { private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql"; - private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata (" - + JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; - - private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php"; - - private static final String SINGLE_SIGNON_URL = "https://localhost/SSO"; - - private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); - - private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false; - - private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO"; - - private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response"; - - private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); - - private static final List SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512"); - - private X509Certificate certificate; - private EmbeddedDatabase db; private JdbcAssertingPartyMetadataRepository repository; private JdbcOperations jdbcOperations; - private final Serializer serializer = new DefaultSerializer(); + private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full() + .build() + .getAssertingPartyMetadata(); @BeforeEach void setUp() { this.db = createDb(); this.jdbcOperations = new JdbcTemplate(this.db); this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations); - this.certificate = loadCertificate("rsa.crt"); } @AfterEach @@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests { } @Test - void findByEntityId() throws IOException { - this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING, - SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, - SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); + void findByEntityId() { + this.repository.save(this.metadata); - AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID); + AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId()); - assertThat(found).isNotNull(); - assertThat(found.getEntityId()).isEqualTo(ENTITY_ID); - assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL); - assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING); - assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST); - assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); - assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); - assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING); - assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0)); - assertThat(found.getVerificationX509Credentials()).hasSize(1); - assertThat(found.getEncryptionX509Credentials()).hasSize(1); + assertAssertingPartyEquals(found, this.metadata); } @Test @@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests { } @Test - void iterator() throws IOException { - this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING, - SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, - SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); - - this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL, - SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST, - this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), - this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, - SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); + void iterator() { + AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata) + .assertingPartyMetadata((a) -> a.entityId("https://example.org/idp")) + .build() + .getAssertingPartyMetadata(); + this.repository.save(this.metadata); + this.repository.save(second); Iterator iterator = this.repository.iterator(); - AssertingPartyMetadata first = iterator.next(); - assertThat(first).isNotNull(); - AssertingPartyMetadata second = iterator.next(); - assertThat(second).isNotNull(); + + assertAssertingPartyEquals(iterator.next(), this.metadata); + assertAssertingPartyEquals(iterator.next(), second); assertThat(iterator.hasNext()).isFalse(); } + @Test + void saveWhenExistingThenUpdates() { + this.repository.save(this.metadata); + boolean existing = this.metadata.getWantAuthnRequestsSigned(); + this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build()); + boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned(); + assertThat(existing).isNotEqualTo(updated); + } + private static EmbeddedDatabase createDb() { return createDb(SCHEMA_SQL_RESOURCE); } @@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests { // @formatter:on } - private X509Certificate loadCertificate(String path) { - try (InputStream is = new ClassPathResource(path).getInputStream()) { - CertificateFactory factory = CertificateFactory.getInstance("X.509"); - return (X509Certificate) factory.generateCertificate(is); - } - catch (Exception ex) { - throw new RuntimeException("Error loading certificate from " + path, ex); - } - } - - private Collection asCredentials(X509Certificate certificate) { - return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION, - Saml2X509Credential.Saml2X509CredentialType.VERIFICATION)); + private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) { + assertThat(found).isNotNull(); + assertThat(found.getEntityId()).isEqualTo(expected.getEntityId()); + assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation()); + assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding()); + assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned()); + assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation()); + assertThat(found.getSingleLogoutServiceResponseLocation()) + .isEqualTo(expected.getSingleLogoutServiceResponseLocation()); + assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding()); + assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms()); + assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials()); + assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials()); } }