From f1a76efc2db6ab361b6121af0244a42ac23dbb77 Mon Sep 17 00:00:00 2001
From: Sander van Schouwenburg <svschouw@betterbe.com>
Date: Fri, 11 Feb 2022 10:52:07 +0100
Subject: [PATCH] Preserve order of RelyingPartRegistration credentials

Issue gh-10799
---
 .../saml2/core/TestSaml2X509Credentials.java  | 51 +++++++++++++++++-
 .../RelyingPartyRegistrationTests.java        | 52 +++++++++++++++++--
 2 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java
index 55cd6b53b9..519a2a254a 100644
--- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java
+++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -51,6 +51,10 @@ public final class TestSaml2X509Credentials {
 		return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.VERIFICATION);
 	}
 
+	public static Saml2X509Credential relyingPartyEncryptingCredential() {
+		return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.ENCRYPTION);
+	}
+
 	public static Saml2X509Credential relyingPartySigningCredential() {
 		return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.SIGNING);
 	}
@@ -59,6 +63,15 @@ public final class TestSaml2X509Credentials {
 		return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.DECRYPTION);
 	}
 
+	public static Saml2X509Credential altPublicCredential() {
+		return new Saml2X509Credential(altCertificate(), Saml2X509CredentialType.VERIFICATION, Saml2X509CredentialType.ENCRYPTION);
+	}
+
+	public static Saml2X509Credential altPrivateCredential() {
+		return new Saml2X509Credential(altPrivateKey(), altCertificate(), Saml2X509CredentialType.SIGNING,
+				Saml2X509CredentialType.DECRYPTION);
+	}
+
 	private static X509Certificate certificate(String cert) {
 		ByteArrayInputStream certBytes = new ByteArrayInputStream(cert.getBytes());
 		try {
@@ -170,4 +183,40 @@ public final class TestSaml2X509Credentials {
 						+ "-----END PRIVATE KEY-----");
 	}
 
+	private static X509Certificate altCertificate() {
+		return certificate(
+				"-----BEGIN CERTIFICATE-----\n"	+ "MIICkDCCAfkCFEstVfmWSFQp/j88GaMUwqVK72adMA0GCSqGSIb3DQEBCwUAMIGG\n"
+						+ "MQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjESMBAGA1UEBwwJVmFu\n"
+						+ "Y291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FNTDEMMAoGA1UECwwD\n"
+						+ "YWx0MSEwHwYDVQQDDBhhbHQuc3ByaW5nLnNlY3VyaXR5LnNhbWwwHhcNMjIwMjEw\n"
+						+ "MTY1ODA4WhcNMzIwMjEwMTY1ODA4WjCBhjELMAkGA1UEBhMCVVMxEzARBgNVBAgM\n"
+						+ "Cldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsGA1UECgwUU3ByaW5n\n"
+						+ "IFNlY3VyaXR5IFNBTUwxDDAKBgNVBAsMA2FsdDEhMB8GA1UEAwwYYWx0LnNwcmlu\n"
+						+ "Zy5zZWN1cml0eS5zYW1sMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC9ZGWj\n"
+						+ "TPDsymQCJL044py4xLsBI/S9RvzNeR9oD/tHyoxCE+YZzjf0PyBtwqKzkKWqCPf4\n"
+						+ "XGUYHfEpkM5kJYwCW8TsOx5fnwLIQweiPqjYrBr/O0IjHMqYG9HlR/ros7iBt4ab\n"
+						+ "EGUu/B9yYg1YRYPxKQ6TNP3AD+9tBT8TsFFyjwIDAQABMA0GCSqGSIb3DQEBCwUA\n"
+						+ "A4GBAKJf2VHLjkCHRxlbWn63jGiquq3ENYgd1JS0DZ3ggFmuc6zQiqxzRGtArIDZ\n"
+						+ "0jH5nrG0jcvO0fqDqBQh0iT8thfUnkViAQvACZ9a+0x0NzUicJ+Ra51c8Z2enqbg\n"
+						+ "pXy+ga67HcAXrDekm1MCGCgiEb/Cgl41lsideqhC8Efl7PRN\n" + "-----END CERTIFICATE-----");
+	}
+
+	private static PrivateKey altPrivateKey() {
+		return privateKey(
+				"-----BEGIN PRIVATE KEY-----\n"	+ "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAL1kZaNM8OzKZAIk\n"
+						+ "vTjinLjEuwEj9L1G/M15H2gP+0fKjEIT5hnON/Q/IG3CorOQpaoI9/hcZRgd8SmQ\n"
+						+ "zmQljAJbxOw7Hl+fAshDB6I+qNisGv87QiMcypgb0eVH+uizuIG3hpsQZS78H3Ji\n"
+						+ "DVhFg/EpDpM0/cAP720FPxOwUXKPAgMBAAECgYEApYKslAZ0cer5dSoYNzNLFOnQ\n"
+						+ "J1H92r/Dw+k6+h0lUvr+keyD5T9jhM76DxHOUDBzpmIKGoDcVDQugk2rILfzXsQA\n"
+						+ "JtwvDRJk32Z02Vt0jb7t/WUOOQhjKCjQuv9/tOx90GCl0VxYG69UOjaMRWrlg/i9\n"
+						+ "6/zcTRIahIn5XxF0psECQQD7ivJCpDbOLJGsc8gNJR4cvjZ1q0mHIOrbKqJC0y1n\n"
+						+ "5DrzGEflPeyCUwnOKNp9HJQP8gmZzXfj0JM9KsjpiUChAkEAwL+FmhDoTiqStIrH\n"
+						+ "h9Kdnsev//imMmRHxjwDhntYvqavUsISRmY3imd8inoYq5dzWQMzBtoTyMRmqeLT\n"
+						+ "DHV1LwJAW4xaV37Eo4z9B7Kr4Hzd1MA1ueW5QQDt+Q4vN/r7z4/1FHyFzh0Xcucd\n"
+						+ "7nZX7qj0CkmgzOVG+Rb0P5LOxJA7gQJBAK1KQ2qNct375qPM9bEGSVGchH6k5X7+\n"
+						+ "q4ztHdpFgTb/EzdbZiTG935GpjC1rwJuinTnrHOnkwv4j7iDRm24GF8CQQDqPvrQ\n"
+						+ "GcItR6UUy0q/B8UxLzlE6t+HiznfiJKfyGgCHU56Y4/ZhzSQz2MZHz9SK4DsUL9s\n"
+						+ "bOYrWq8VY2fyjV1t\n" + "-----END PRIVATE KEY-----");
+	}
+
 }
diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java
index d25d4b981c..b6a6c52276 100644
--- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java
+++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java
@@ -17,8 +17,8 @@
 package org.springframework.security.saml2.provider.service.registration;
 
 import org.junit.jupiter.api.Test;
-
-import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -81,9 +81,53 @@ public class RelyingPartyRegistrationTests {
 		RelyingPartyRegistration relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id")
 				.entityId("entity-id").assertionConsumerServiceLocation("location")
 				.assertingPartyDetails((assertingParty) -> assertingParty.entityId("entity-id")
-						.singleSignOnServiceLocation("location"))
-				.credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())).build();
+						.singleSignOnServiceLocation("location")
+						.verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))
+				).build();
 		assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
 	}
 
+	@Test
+	public void buildPreservesCredentialsOrder() {
+		Saml2X509Credential altRpCredential = TestSaml2X509Credentials.altPrivateCredential();
+		Saml2X509Credential altApCredential = TestSaml2X509Credentials.altPublicCredential();
+		Saml2X509Credential verifyingCredential = TestSaml2X509Credentials.relyingPartyVerifyingCredential();
+		Saml2X509Credential encryptingCredential = TestSaml2X509Credentials.relyingPartyEncryptingCredential();
+		Saml2X509Credential signingCredential = TestSaml2X509Credentials.relyingPartySigningCredential();
+		Saml2X509Credential decryptionCredential = TestSaml2X509Credentials.relyingPartyDecryptingCredential();
+
+		// Test with the alt credentials first
+		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials()
+				.assertingPartyDetails((assertingParty) -> assertingParty
+						.verificationX509Credentials((c) -> { c.add(altApCredential); c.add(verifyingCredential); })
+						.encryptionX509Credentials((c) -> { c.add(altApCredential); c.add(encryptingCredential); }))
+				.signingX509Credentials(c -> { c.add(altRpCredential); c.add(signingCredential); })
+				.decryptionX509Credentials(c -> { c.add(altRpCredential); c.add(decryptionCredential); })
+				.build();
+		assertThat(relyingPartyRegistration.getSigningX509Credentials())
+				.containsExactly(altRpCredential, signingCredential);
+		assertThat(relyingPartyRegistration.getDecryptionX509Credentials())
+				.containsExactly(altRpCredential, decryptionCredential);
+		assertThat(relyingPartyRegistration.getAssertingPartyDetails().getVerificationX509Credentials())
+				.containsExactly(altApCredential, verifyingCredential);
+		assertThat(relyingPartyRegistration.getAssertingPartyDetails().getEncryptionX509Credentials())
+				.containsExactly(altApCredential, encryptingCredential);
+
+		// Test with the alt credentials last
+		relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials()
+				.assertingPartyDetails((assertingParty) -> assertingParty
+						.verificationX509Credentials((c) -> { c.add(verifyingCredential); c.add(altApCredential); })
+						.encryptionX509Credentials((c) -> { c.add(encryptingCredential); c.add(altApCredential); }))
+				.signingX509Credentials(c -> { c.add(signingCredential); c.add(altRpCredential); })
+				.decryptionX509Credentials(c -> { c.add(decryptionCredential); c.add(altRpCredential); })
+				.build();
+		assertThat(relyingPartyRegistration.getSigningX509Credentials())
+				.containsExactly(signingCredential, altRpCredential);
+		assertThat(relyingPartyRegistration.getDecryptionX509Credentials())
+				.containsExactly(decryptionCredential, altRpCredential);
+		assertThat(relyingPartyRegistration.getAssertingPartyDetails().getVerificationX509Credentials())
+				.containsExactly(verifyingCredential, altApCredential);
+		assertThat(relyingPartyRegistration.getAssertingPartyDetails().getEncryptionX509Credentials())
+				.containsExactly(encryptingCredential, altApCredential);
+	}
 }