diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 8b40d415d6..b05c1bbd57 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -106,6 +106,7 @@ dependencies { provided 'jakarta.servlet:jakarta.servlet-api' optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'org.springframework:spring-jdbc' testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation "org.assertj:assertj-core" @@ -118,6 +119,7 @@ dependencies { testImplementation "org.springframework:spring-test" testRuntimeOnly 'org.junit.platform:junit-platform-launcher' + testRuntimeOnly 'org.hsqldb:hsqldb' } jar { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java new file mode 100644 index 0000000000..620e6bdf2e --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.core.serializer.DefaultDeserializer; +import org.springframework.core.serializer.Deserializer; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails; +import org.springframework.util.Assert; + +/** + * A JDBC implementation of {@link AssertingPartyMetadataRepository}. + * + * @author Cathy Wang + * @since 7.0 + */ +public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository { + + private final JdbcOperations jdbcOperations; + + private RowMapper assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper( + ResultSet::getBytes); + + // @formatter:off + static final String COLUMN_NAMES = "entity_id, " + + "singlesignon_url, " + + "singlesignon_binding, " + + "singlesignon_sign_request, " + + "signing_algorithms, " + + "verification_credentials, " + + "encryption_credentials, " + + "singlelogout_url, " + + "singlelogout_response_url, " + + "singlelogout_binding"; + // @formatter:on + + private static final String TABLE_NAME = "saml2_asserting_party_metadata"; + + private static final String ENTITY_ID_FILTER = "entity_id = ?"; + + // @formatter:off + private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + ENTITY_ID_FILTER; + + private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME; + // @formatter:on + + /** + * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided + * parameters. + * @param jdbcOperations the JDBC operations + */ + public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + this.jdbcOperations = jdbcOperations; + } + + /** + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. The default is + * {@link AssertingPartyMetadataRowMapper}. + * @param assertingPartyMetadataRowMapper the {@link RowMapper} used for mapping the + * current row in {@code java.sql.ResultSet} to {@link AssertingPartyMetadata} + */ + public void setAssertingPartyMetadataRowMapper(RowMapper assertingPartyMetadataRowMapper) { + Assert.notNull(assertingPartyMetadataRowMapper, "assertingPartyMetadataRowMapper cannot be null"); + this.assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper; + } + + @Override + public AssertingPartyMetadata findByEntityId(String entityId) { + Assert.hasText(entityId, "entityId cannot be empty"); + SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss, + this.assertingPartyMetadataRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + @Override + public Iterator iterator() { + List result = this.jdbcOperations.query(LOAD_ALL_SQL, + this.assertingPartyMetadataRowMapper); + return result.iterator(); + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. + */ + private static final class AssertingPartyMetadataRowMapper implements RowMapper { + + private final Log logger = LogFactory.getLog(AssertingPartyMetadataRowMapper.class); + + private final Deserializer deserializer = new DefaultDeserializer(); + + private final GetBytes getBytes; + + AssertingPartyMetadataRowMapper(GetBytes getBytes) { + this.getBytes = getBytes; + } + + @Override + public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException { + String entityId = rs.getString("entity_id"); + String singleSignOnUrl = rs.getString("singlesignon_url"); + Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString("singlesignon_binding")); + boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request"); + String singleLogoutUrl = rs.getString("singlelogout_url"); + String singleLogoutResponseUrl = rs.getString("singlelogout_response_url"); + Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding")); + byte[] signingAlgorithmsBytes = this.getBytes.getBytes(rs, "signing_algorithms"); + byte[] verificationCredentialsBytes = this.getBytes.getBytes(rs, "verification_credentials"); + byte[] encryptionCredentialsBytes = this.getBytes.getBytes(rs, "encryption_credentials"); + + AssertingPartyMetadata.Builder builder = new AssertingPartyDetails.Builder(); + try { + if (signingAlgorithmsBytes != null) { + List signingAlgorithms = (List) this.deserializer + .deserializeFromByteArray(signingAlgorithmsBytes); + builder.signingAlgorithms((algorithms) -> algorithms.addAll(signingAlgorithms)); + } + if (verificationCredentialsBytes != null) { + Collection verificationCredentials = (Collection) this.deserializer + .deserializeFromByteArray(verificationCredentialsBytes); + builder.verificationX509Credentials((credentials) -> credentials.addAll(verificationCredentials)); + } + if (encryptionCredentialsBytes != null) { + Collection encryptionCredentials = (Collection) this.deserializer + .deserializeFromByteArray(encryptionCredentialsBytes); + builder.encryptionX509Credentials((credentials) -> credentials.addAll(encryptionCredentials)); + } + } + catch (Exception ex) { + this.logger.debug(LogMessage.format("Parsing serialized credentials for entity %s failed", entityId), + ex); + return null; + } + + builder.entityId(entityId) + .wantAuthnRequestsSigned(singleSignOnSignRequest) + .singleSignOnServiceLocation(singleSignOnUrl) + .singleSignOnServiceBinding(singleSignOnBinding) + .singleLogoutServiceLocation(singleLogoutUrl) + .singleLogoutServiceBinding(singleLogoutBinding) + .singleLogoutServiceResponseLocation(singleLogoutResponseUrl); + return builder.build(); + } + + } + + private interface GetBytes { + + byte[] getBytes(ResultSet rs, String columnName) throws SQLException; + + } + +} diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql new file mode 100644 index 0000000000..ffa047fe7b --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql @@ -0,0 +1,14 @@ +CREATE TABLE saml2_asserting_party_metadata +( + entity_id VARCHAR(1000) NOT NULL, + singlesignon_url VARCHAR(1000) NOT NULL, + singlesignon_binding VARCHAR(100), + singlesignon_sign_request boolean, + signing_algorithms BYTEA, + verification_credentials BYTEA NOT NULL, + encryption_credentials BYTEA, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(100), + PRIMARY KEY (entity_id) +); diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql new file mode 100644 index 0000000000..2fd6cb8cdf --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql @@ -0,0 +1,14 @@ +CREATE TABLE saml2_asserting_party_metadata +( + entity_id VARCHAR(1000) NOT NULL, + singlesignon_url VARCHAR(1000) NOT NULL, + singlesignon_binding VARCHAR(100), + singlesignon_sign_request boolean, + signing_algorithms blob, + verification_credentials blob NOT NULL, + encryption_credentials blob, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(100), + PRIMARY KEY (entity_id) +); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java new file mode 100644 index 0000000000..c734bcb236 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java @@ -0,0 +1,177 @@ +package org.springframework.security.saml2.provider.service.registration; + +import java.io.IOException; +import java.io.InputStream; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.serializer.DefaultSerializer; +import org.springframework.core.serializer.Serializer; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.saml2.core.Saml2X509Credential; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link JdbcAssertingPartyMetadataRepository} + */ +class JdbcAssertingPartyMetadataRepositoryTests { + + private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql"; + + private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata (" + + JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php"; + + private static final String SINGLE_SIGNON_URL = "https://localhost/SSO"; + + private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); + + private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false; + + private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO"; + + private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response"; + + private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); + + private static final List SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512"); + + private X509Certificate certificate; + + private EmbeddedDatabase db; + + private JdbcAssertingPartyMetadataRepository repository; + + private JdbcOperations jdbcOperations; + + private final Serializer serializer = new DefaultSerializer(); + + @BeforeEach + public void setUp() throws Exception { + this.db = createDb(); + this.jdbcOperations = new JdbcTemplate(this.db); + this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations); + this.certificate = loadCertificate("rsa.crt"); + } + + @AfterEach + public void tearDown() { + this.db.shutdown(); + } + + @Test + void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcAssertingPartyMetadataRepository(null)) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + void findByEntityIdWhenEntityIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.findByEntityId(null)) + .withMessage("entityId cannot be empty"); + // @formatter:on + } + + @Test + void findByEntityId() throws IOException { + this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING, + SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); + + AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID); + + assertThat(found).isNotNull(); + assertThat(found.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL); + assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING); + assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST); + assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING); + assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0)); + assertThat(found.getVerificationX509Credentials()).hasSize(1); + assertThat(found.getEncryptionX509Credentials()).hasSize(1); + } + + @Test + void findByEntityIdWhenNotExists() { + AssertingPartyMetadata found = this.repository.findByEntityId("non-existent-entity-id"); + assertThat(found).isNull(); + } + + @Test + void iterator() throws IOException { + this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING, + SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); + + this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL, + SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST, + this.serializer.serializeToByteArray(SIGNING_ALGORITHMS), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), + this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING); + + Iterator iterator = this.repository.iterator(); + AssertingPartyMetadata first = iterator.next(); + assertThat(first).isNotNull(); + AssertingPartyMetadata second = iterator.next(); + assertThat(second).isNotNull(); + assertThat(iterator.hasNext()).isFalse(); + } + + private static EmbeddedDatabase createDb() { + return createDb(SCHEMA_SQL_RESOURCE); + } + + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } + + private X509Certificate loadCertificate(String path) { + try (InputStream is = new ClassPathResource(path).getInputStream()) { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + return (X509Certificate) factory.generateCertificate(is); + } + catch (Exception ex) { + throw new RuntimeException("Error loading certificate from " + path, ex); + } + } + + private Collection asCredentials(X509Certificate certificate) { + return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION, + Saml2X509Credential.Saml2X509CredentialType.VERIFICATION)); + } + +} diff --git a/saml2/saml2-service-provider/src/test/resources/rsa.crt b/saml2/saml2-service-provider/src/test/resources/rsa.crt new file mode 100644 index 0000000000..aa147065de --- /dev/null +++ b/saml2/saml2-service-provider/src/test/resources/rsa.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID1zCCAr+gAwIBAgIUCzQeKBMTO0iHVW3iKmZC41haqCowDQYJKoZIhvcNAQEL +BQAwezELMAkGA1UEBhMCWFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwI +Q2l0eU5hbWUxFDASBgNVBAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55 +U2VjdGlvbk5hbWUxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzA5MjAwODI5MDNa +Fw0zMzA5MTcwODI5MDNaMHsxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5h +bWUxETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkG +A1UECwwSQ29tcGFueVNlY3Rpb25OYW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUfi4aaCotJZX6OSDjv6fxCCfc +ihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTkFPbeDxAl75Oa +PGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQXnIXTBL2o76V +nuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB8PFFhCZXW2/9 +TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0sburHUlOnPGUh +Qj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQjHcvET9jvAgMB +AAGjUzBRMB0GA1UdDgQWBBQjDr/1E/01pfLPD8uWF7gbaYL0TTAfBgNVHSMEGDAW +gBQjDr/1E/01pfLPD8uWF7gbaYL0TTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3 +DQEBCwUAA4IBAQAGjUuec0+0XNMCRDKZslbImdCAVsKsEWk6NpnUViDFAxL+KQuC +NW131UeHb9SCzMqRwrY4QI3nAwJQCmilL/hFM3ss4acn3WHu1yci/iKPUKeL1ec5 +kCFUmqX1NpTiVaytZ/9TKEr69SMVqNfQiuW5U1bIIYTqK8xo46WpM6YNNHO3eJK6 +NH0MW79Wx5ryi4i4C6afqYbVbx7tqcmy8CFeNxgZ0bFQ87SiwYXIj77b6sVYbu32 +doykBQgSHLcagWASPQ73m73CWUgo+7+EqSKIQqORbgmTLPmOUh99gFIx7jmjTyHm +NBszx1ZVWuIv3mWmp626Kncyc+LLM9tvgymx +-----END CERTIFICATE-----