From 7b398006068778ad67cd738d8a51df0765414c5f Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Mon, 1 Jul 2024 19:10:57 -0600 Subject: [PATCH] Add CachingRelyingPartyRegistrationRepository Closes gh-15341 --- .../pages/servlet/saml2/login/overview.adoc | 51 ++++++++++ ...ingRelyingPartyRegistrationRepository.java | 95 +++++++++++++++++++ ...lyingPartyRegistrationRepositoryTests.java | 81 ++++++++++++++++ 3 files changed, 227 insertions(+) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepository.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepositoryTests.java diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc index 2ab0b7554c..0721d9b15e 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc @@ -588,6 +588,57 @@ class MyCustomSecurityConfiguration { A relying party can be multi-tenant by registering more than one relying party in the `RelyingPartyRegistrationRepository`. ==== +[[servlet-saml2login-relyingpartyregistrationrepository-caching]] +If you want your metadata to be refreshable on a periodic basis, you can wrap your repository in `CachingRelyingPartyRegistrationRepository` like so: + +.Caching Relying Party Registration Repository +[tabs] +====== +Java:: ++ +[source,java,role="primary"] +---- +@Configuration +@EnableWebSecurity +public class MyCustomSecurityConfiguration { + @Bean + public RelyingPartyRegistrationRepository registrations(CacheManager cacheManager) { + Supplier delegate = () -> + new InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations + .fromMetadataLocation("https://idp.example.org/ap/metadata") + .registrationId("ap").build()); + CachingRelyingPartyRegistrationRepository registrations = + new CachingRelyingPartyRegistrationRepository(delegate); + registrations.setCache(cacheManager.getCache("my-cache-name")); + return registrations; + } +} +---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Configuration +@EnableWebSecurity +class MyCustomSecurityConfiguration { + @Bean + fun registrations(cacheManager: CacheManager): RelyingPartyRegistrationRepository { + val delegate = Supplier { + InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations + .fromMetadataLocation("https://idp.example.org/ap/metadata") + .registrationId("ap").build()) + } + val registrations = CachingRelyingPartyRegistrationRepository(delegate) + registrations.setCache(cacheManager.getCache("my-cache-name")) + return registrations + } +} +---- +====== + +In this way, the set of `RelyingPartyRegistration`s will refresh based on {spring-framework-reference-url}integration/cache/store-configuration.html[the cache's eviction schedule]. + [[servlet-saml2login-relyingpartyregistration]] == RelyingPartyRegistration A {security-api-url}org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.html[`RelyingPartyRegistration`] diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepository.java new file mode 100644 index 0000000000..bfc39e486f --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepository.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.util.Iterator; +import java.util.Spliterator; +import java.util.concurrent.Callable; +import java.util.function.Consumer; + +import org.springframework.cache.Cache; +import org.springframework.cache.concurrent.ConcurrentMapCache; +import org.springframework.util.Assert; + +/** + * An {@link IterableRelyingPartyRegistrationRepository} that lazily queries and caches + * metadata from a backing {@link IterableRelyingPartyRegistrationRepository}. Delegates + * caching policies to Spring Cache. + * + * @author Josh Cummings + * @since 6.4 + */ +public final class CachingRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository { + + private final Callable registrationLoader; + + private Cache cache = new ConcurrentMapCache("registrations"); + + public CachingRelyingPartyRegistrationRepository(Callable loader) { + this.registrationLoader = loader; + } + + /** + * {@inheritDoc} + */ + @Override + public Iterator iterator() { + return registrations().iterator(); + } + + /** + * {@inheritDoc} + */ + @Override + public RelyingPartyRegistration findByRegistrationId(String registrationId) { + return registrations().findByRegistrationId(registrationId); + } + + @Override + public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) { + return registrations().findUniqueByAssertingPartyEntityId(entityId); + } + + @Override + public void forEach(Consumer action) { + registrations().forEach(action); + } + + @Override + public Spliterator spliterator() { + return registrations().spliterator(); + } + + private IterableRelyingPartyRegistrationRepository registrations() { + return this.cache.get("registrations", this.registrationLoader); + } + + /** + * Use this cache for the completed {@link RelyingPartyRegistration} instances. + * + *

+ * Defaults to {@link ConcurrentMapCache}, meaning that the registrations are cached + * without expiry. To turn off the cache, use + * {@link org.springframework.cache.support.NoOpCache}. + * @param cache the {@link Cache} to use + */ + public void setCache(Cache cache) { + Assert.notNull(cache, "cache cannot be null"); + this.cache = cache; + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepositoryTests.java new file mode 100644 index 0000000000..7e4d57d444 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepositoryTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.util.concurrent.Callable; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.cache.Cache; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link CachingRelyingPartyRegistrationRepository} + */ +@ExtendWith(MockitoExtension.class) +public class CachingRelyingPartyRegistrationRepositoryTests { + + @Mock + Callable> callable; + + @InjectMocks + CachingRelyingPartyRegistrationRepository registrations; + + @Test + public void iteratorWhenResolvableThenPopulatesCache() throws Exception { + given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class)); + this.registrations.iterator(); + verify(this.callable).call(); + this.registrations.iterator(); + verifyNoMoreInteractions(this.callable); + } + + @Test + public void iteratorWhenExceptionThenPropagates() throws Exception { + given(this.callable.call()).willThrow(IllegalStateException.class); + assertThatExceptionOfType(Cache.ValueRetrievalException.class).isThrownBy(this.registrations::iterator) + .withCauseInstanceOf(IllegalStateException.class); + } + + @Test + public void findByRegistrationIdWhenResolvableThenPopulatesCache() throws Exception { + given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class)); + this.registrations.findByRegistrationId("id"); + verify(this.callable).call(); + this.registrations.findByRegistrationId("id"); + verifyNoMoreInteractions(this.callable); + } + + @Test + public void findUniqueByAssertingPartyEntityIdWhenResolvableThenPopulatesCache() throws Exception { + given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class)); + this.registrations.findUniqueByAssertingPartyEntityId("id"); + verify(this.callable).call(); + this.registrations.findUniqueByAssertingPartyEntityId("id"); + verifyNoMoreInteractions(this.callable); + } + +}