mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-06-23 20:42:14 +00:00
Add JdbcAssertingPartyMetadataRepository#save
Issue gh-16012 Co-Authored-By: chao.wang <chao.wang@zatech.com>
This commit is contained in:
parent
e2e42a5580
commit
2bd05128ec
@ -19,16 +19,20 @@ package org.springframework.security.saml2.provider.service.registration;
|
|||||||
import java.sql.ResultSet;
|
import java.sql.ResultSet;
|
||||||
import java.sql.SQLException;
|
import java.sql.SQLException;
|
||||||
import java.sql.Types;
|
import java.sql.Types;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
import org.springframework.core.log.LogMessage;
|
import org.springframework.core.log.LogMessage;
|
||||||
import org.springframework.core.serializer.DefaultDeserializer;
|
import org.springframework.core.serializer.DefaultDeserializer;
|
||||||
|
import org.springframework.core.serializer.DefaultSerializer;
|
||||||
import org.springframework.core.serializer.Deserializer;
|
import org.springframework.core.serializer.Deserializer;
|
||||||
|
import org.springframework.core.serializer.Serializer;
|
||||||
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
|
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
|
||||||
import org.springframework.jdbc.core.JdbcOperations;
|
import org.springframework.jdbc.core.JdbcOperations;
|
||||||
import org.springframework.jdbc.core.PreparedStatementSetter;
|
import org.springframework.jdbc.core.PreparedStatementSetter;
|
||||||
@ -37,6 +41,7 @@ import org.springframework.jdbc.core.SqlParameterValue;
|
|||||||
import org.springframework.security.saml2.core.Saml2X509Credential;
|
import org.springframework.security.saml2.core.Saml2X509Credential;
|
||||||
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.function.ThrowingFunction;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
|
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
|
||||||
@ -51,6 +56,8 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
|
|||||||
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
|
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
|
||||||
ResultSet::getBytes);
|
ResultSet::getBytes);
|
||||||
|
|
||||||
|
private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();
|
||||||
|
|
||||||
// @formatter:off
|
// @formatter:off
|
||||||
static final String COLUMN_NAMES = "entity_id, "
|
static final String COLUMN_NAMES = "entity_id, "
|
||||||
+ "singlesignon_url, "
|
+ "singlesignon_url, "
|
||||||
@ -77,6 +84,25 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
|
|||||||
+ " FROM " + TABLE_NAME;
|
+ " FROM " + TABLE_NAME;
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
|
|
||||||
|
// @formatter:off
|
||||||
|
private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
|
||||||
|
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
|
// @formatter:off
|
||||||
|
private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
|
||||||
|
+ " SET singlesignon_url = ?, "
|
||||||
|
+ "singlesignon_binding = ?, "
|
||||||
|
+ "singlesignon_sign_request = ?, "
|
||||||
|
+ "signing_algorithms = ?, "
|
||||||
|
+ "verification_credentials = ?, "
|
||||||
|
+ "encryption_credentials = ?, "
|
||||||
|
+ "singlelogout_url = ?, "
|
||||||
|
+ "singlelogout_response_url = ?, "
|
||||||
|
+ "singlelogout_binding = ?"
|
||||||
|
+ " WHERE " + ENTITY_ID_FILTER;
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
|
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
|
||||||
* parameters.
|
* parameters.
|
||||||
@ -116,6 +142,30 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
|
|||||||
return result.iterator();
|
return result.iterator();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Persist this {@link AssertingPartyMetadata}
|
||||||
|
* @param metadata the metadata to persist
|
||||||
|
*/
|
||||||
|
public void save(AssertingPartyMetadata metadata) {
|
||||||
|
Assert.notNull(metadata, "metadata cannot be null");
|
||||||
|
int rows = updateCredentialRecord(metadata);
|
||||||
|
if (rows == 0) {
|
||||||
|
insertCredentialRecord(metadata);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void insertCredentialRecord(AssertingPartyMetadata metadata) {
|
||||||
|
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
|
||||||
|
this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
private int updateCredentialRecord(AssertingPartyMetadata metadata) {
|
||||||
|
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
|
||||||
|
SqlParameterValue credentialId = parameters.remove(0);
|
||||||
|
parameters.add(credentialId);
|
||||||
|
return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The default {@link RowMapper} that maps the current row in
|
* The default {@link RowMapper} that maps the current row in
|
||||||
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
|
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
|
||||||
@ -181,6 +231,34 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class AssertingPartyMetadataParametersMapper
|
||||||
|
implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {
|
||||||
|
|
||||||
|
private final Serializer<Object> serializer = new DefaultSerializer();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
|
||||||
|
List<SqlParameterValue> parameters = new ArrayList<>();
|
||||||
|
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
|
||||||
|
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
|
||||||
|
ThrowingFunction<List<String>, byte[]> algorithms = this.serializer::serializeToByteArray;
|
||||||
|
parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms())));
|
||||||
|
ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
|
||||||
|
parameters
|
||||||
|
.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
|
||||||
|
parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials())));
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
|
||||||
|
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn()));
|
||||||
|
|
||||||
|
return parameters;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private interface GetBytes {
|
private interface GetBytes {
|
||||||
|
|
||||||
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
|
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
|
||||||
|
@ -16,27 +16,17 @@
|
|||||||
|
|
||||||
package org.springframework.security.saml2.provider.service.registration;
|
package org.springframework.security.saml2.provider.service.registration;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.security.cert.CertificateFactory;
|
|
||||||
import java.security.cert.X509Certificate;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import org.springframework.core.io.ClassPathResource;
|
|
||||||
import org.springframework.core.serializer.DefaultSerializer;
|
|
||||||
import org.springframework.core.serializer.Serializer;
|
|
||||||
import org.springframework.jdbc.core.JdbcOperations;
|
import org.springframework.jdbc.core.JdbcOperations;
|
||||||
import org.springframework.jdbc.core.JdbcTemplate;
|
import org.springframework.jdbc.core.JdbcTemplate;
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
||||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
||||||
import org.springframework.security.saml2.core.Saml2X509Credential;
|
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
|
|||||||
|
|
||||||
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
|
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
|
||||||
|
|
||||||
private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
|
|
||||||
+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
|
|
||||||
|
|
||||||
private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
|
|
||||||
|
|
||||||
private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
|
|
||||||
|
|
||||||
private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
|
|
||||||
|
|
||||||
private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
|
|
||||||
|
|
||||||
private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
|
|
||||||
|
|
||||||
private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
|
|
||||||
|
|
||||||
private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
|
|
||||||
|
|
||||||
private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
|
|
||||||
|
|
||||||
private X509Certificate certificate;
|
|
||||||
|
|
||||||
private EmbeddedDatabase db;
|
private EmbeddedDatabase db;
|
||||||
|
|
||||||
private JdbcAssertingPartyMetadataRepository repository;
|
private JdbcAssertingPartyMetadataRepository repository;
|
||||||
|
|
||||||
private JdbcOperations jdbcOperations;
|
private JdbcOperations jdbcOperations;
|
||||||
|
|
||||||
private final Serializer<Object> serializer = new DefaultSerializer();
|
private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full()
|
||||||
|
.build()
|
||||||
|
.getAssertingPartyMetadata();
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setUp() {
|
void setUp() {
|
||||||
this.db = createDb();
|
this.db = createDb();
|
||||||
this.jdbcOperations = new JdbcTemplate(this.db);
|
this.jdbcOperations = new JdbcTemplate(this.db);
|
||||||
this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
|
this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
|
||||||
this.certificate = loadCertificate("rsa.crt");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void findByEntityId() throws IOException {
|
void findByEntityId() {
|
||||||
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
|
this.repository.save(this.metadata);
|
||||||
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
|
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
|
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
|
|
||||||
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
|
|
||||||
|
|
||||||
AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
|
AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId());
|
||||||
|
|
||||||
assertThat(found).isNotNull();
|
assertAssertingPartyEquals(found, this.metadata);
|
||||||
assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
|
|
||||||
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
|
|
||||||
assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
|
|
||||||
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
|
|
||||||
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
|
|
||||||
assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
|
|
||||||
assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
|
|
||||||
assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
|
|
||||||
assertThat(found.getVerificationX509Credentials()).hasSize(1);
|
|
||||||
assertThat(found.getEncryptionX509Credentials()).hasSize(1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void iterator() throws IOException {
|
void iterator() {
|
||||||
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
|
AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata)
|
||||||
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
|
.assertingPartyMetadata((a) -> a.entityId("https://example.org/idp"))
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
|
.build()
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
|
.getAssertingPartyMetadata();
|
||||||
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
|
this.repository.save(this.metadata);
|
||||||
|
this.repository.save(second);
|
||||||
this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
|
|
||||||
SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
|
|
||||||
this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
|
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
|
|
||||||
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
|
|
||||||
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
|
|
||||||
|
|
||||||
Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
|
Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
|
||||||
AssertingPartyMetadata first = iterator.next();
|
|
||||||
assertThat(first).isNotNull();
|
assertAssertingPartyEquals(iterator.next(), this.metadata);
|
||||||
AssertingPartyMetadata second = iterator.next();
|
assertAssertingPartyEquals(iterator.next(), second);
|
||||||
assertThat(second).isNotNull();
|
|
||||||
assertThat(iterator.hasNext()).isFalse();
|
assertThat(iterator.hasNext()).isFalse();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void saveWhenExistingThenUpdates() {
|
||||||
|
this.repository.save(this.metadata);
|
||||||
|
boolean existing = this.metadata.getWantAuthnRequestsSigned();
|
||||||
|
this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build());
|
||||||
|
boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned();
|
||||||
|
assertThat(existing).isNotEqualTo(updated);
|
||||||
|
}
|
||||||
|
|
||||||
private static EmbeddedDatabase createDb() {
|
private static EmbeddedDatabase createDb() {
|
||||||
return createDb(SCHEMA_SQL_RESOURCE);
|
return createDb(SCHEMA_SQL_RESOURCE);
|
||||||
}
|
}
|
||||||
@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests {
|
|||||||
// @formatter:on
|
// @formatter:on
|
||||||
}
|
}
|
||||||
|
|
||||||
private X509Certificate loadCertificate(String path) {
|
private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) {
|
||||||
try (InputStream is = new ClassPathResource(path).getInputStream()) {
|
assertThat(found).isNotNull();
|
||||||
CertificateFactory factory = CertificateFactory.getInstance("X.509");
|
assertThat(found.getEntityId()).isEqualTo(expected.getEntityId());
|
||||||
return (X509Certificate) factory.generateCertificate(is);
|
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation());
|
||||||
}
|
assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding());
|
||||||
catch (Exception ex) {
|
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned());
|
||||||
throw new RuntimeException("Error loading certificate from " + path, ex);
|
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation());
|
||||||
}
|
assertThat(found.getSingleLogoutServiceResponseLocation())
|
||||||
}
|
.isEqualTo(expected.getSingleLogoutServiceResponseLocation());
|
||||||
|
assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding());
|
||||||
private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
|
assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms());
|
||||||
return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
|
assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials());
|
||||||
Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
|
assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user