diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java index b4663507b6..2da900fd7c 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java @@ -20,16 +20,23 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.junit.jupiter.api.BeforeEach; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; @@ -68,52 +75,59 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { OpenSamlInitializationService.initialize(); } - private String metadata; + private static MetadataDispatcher dispatcher = new MetadataDispatcher() + .addResponse("/entity.xml", readFile("test-metadata.xml")) + .addResponse("/entities.xml", readFile("test-entitiesdescriptor.xml")); - private String entitiesDescriptor; + private static MockWebServer web = new MockWebServer(); - @BeforeEach - public void setup() throws Exception { - ClassPathResource resource = new ClassPathResource("test-metadata.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.metadata = reader.lines().collect(Collectors.joining()); + private static String readFile(String fileName) { + try { + ClassPathResource resource = new ClassPathResource(fileName); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + return reader.lines().collect(Collectors.joining()); + } } - resource = new ClassPathResource("test-entitiesdescriptor.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.entitiesDescriptor = reader.lines().collect(Collectors.joining()); + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + @BeforeAll + public static void start() throws Exception { + web.setDispatcher(dispatcher); + web.start(); + } + + @AfterAll + public static void shutdown() throws Exception { + web.shutdown(); + } + @Test public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.metadata, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .build(); - AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); - assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); - assertThat(party.getSingleSignOnServiceLocation()) - .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); - assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(party.getVerificationX509Credentials()).hasSize(1); - assertThat(party.getEncryptionX509Credentials()).hasSize(1); - } + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url("/entity.xml").toString()) + .build(); + AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); } @Test public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.entitiesDescriptor, 3); - List parties = new ArrayList<>(); - OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString()) - .build() - .iterator() - .forEachRemaining(parties::add); - assertThat(parties).hasSize(2); - assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) - .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); - } + List parties = new ArrayList<>(); + OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(web.url("/entities.xml").toString()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); } @Test @@ -128,12 +142,10 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { @Test public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, "malformed", 3); - String url = server.url("/").toString(); - assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); - } + dispatcher.addResponse("/malformed", "malformed"); + String url = web.url("/malformed").toString(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); } @Test @@ -211,14 +223,13 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } @Test @@ -230,13 +241,12 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build()); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build()); } @Test @@ -326,14 +336,13 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } private static String serialize(XMLObject object) { @@ -353,4 +362,28 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { } } + private static final class MetadataDispatcher extends Dispatcher { + + private final MockResponse head = new MockResponse(); + + private final Map responses = new ConcurrentHashMap<>(); + + private MetadataDispatcher() { + } + + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + if ("HEAD".equals(request.getMethod())) { + return this.head; + } + return this.responses.get(request.getPath()); + } + + private MetadataDispatcher addResponse(String path, String body) { + this.responses.put(path, new MockResponse().setBody(body).setResponseCode(200)); + return this; + } + + } + } diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java index 27c0fd5adb..c01bb82ea6 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java @@ -20,16 +20,23 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import net.shibboleth.shared.xml.SerializeSupport; +import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.junit.jupiter.api.BeforeEach; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; @@ -68,52 +75,59 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { OpenSamlInitializationService.initialize(); } - private String metadata; + private static MetadataDispatcher dispatcher = new MetadataDispatcher() + .addResponse("/entity.xml", readFile("test-metadata.xml")) + .addResponse("/entities.xml", readFile("test-entitiesdescriptor.xml")); - private String entitiesDescriptor; + private static MockWebServer web = new MockWebServer(); - @BeforeEach - public void setup() throws Exception { - ClassPathResource resource = new ClassPathResource("test-metadata.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.metadata = reader.lines().collect(Collectors.joining()); + private static String readFile(String fileName) { + try { + ClassPathResource resource = new ClassPathResource(fileName); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + return reader.lines().collect(Collectors.joining()); + } } - resource = new ClassPathResource("test-entitiesdescriptor.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.entitiesDescriptor = reader.lines().collect(Collectors.joining()); + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + @BeforeAll + public static void start() throws Exception { + web.setDispatcher(dispatcher); + web.start(); + } + + @AfterAll + public static void shutdown() throws Exception { + web.shutdown(); + } + @Test public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.metadata, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .build(); - AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); - assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); - assertThat(party.getSingleSignOnServiceLocation()) - .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); - assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(party.getVerificationX509Credentials()).hasSize(1); - assertThat(party.getEncryptionX509Credentials()).hasSize(1); - } + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url("/entity.xml").toString()) + .build(); + AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); } @Test public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.entitiesDescriptor, 3); - List parties = new ArrayList<>(); - OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString()) - .build() - .iterator() - .forEachRemaining(parties::add); - assertThat(parties).hasSize(2); - assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) - .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); - } + List parties = new ArrayList<>(); + OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(web.url("/entities.xml").toString()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); } @Test @@ -128,12 +142,10 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { @Test public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, "malformed", 3); - String url = server.url("/").toString(); - assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); - } + dispatcher.addResponse("/malformed", "malformed"); + String url = web.url("/malformed").toString(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); } @Test @@ -211,14 +223,13 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } @Test @@ -230,13 +241,12 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build()); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build()); } @Test @@ -326,14 +336,13 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } private static String serialize(XMLObject object) { @@ -353,4 +362,28 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { } } + private static final class MetadataDispatcher extends Dispatcher { + + private final MockResponse head = new MockResponse(); + + private final Map responses = new ConcurrentHashMap<>(); + + private MetadataDispatcher() { + } + + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + if ("HEAD".equals(request.getMethod())) { + return this.head; + } + return this.responses.get(request.getPath()); + } + + private MetadataDispatcher addResponse(String path, String body) { + this.responses.put(path, new MockResponse().setBody(body).setResponseCode(200)); + return this; + } + + } + }