diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc index 2798eea564..29505f3024 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc @@ -83,6 +83,46 @@ open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository ---- ====== +=== Caching the `` by the Relay State + +If you don't want to use the session to store the ``, you can also store it in a distributed cache. +This can be helpful if you are trying to use `SameSite=Strict` and are losing the authentication request in the redirect from the Identity Provider. + +[NOTE] +===== +It's important to remember that there are security benefits to storing it in the session. +One such benefit is the natural login fixation defense it provides. +For example, if an application looks the authentication request up from the session, then even if an attacker provides their own SAML response to a victim, the login will fail. + +On the other hand, if we trust the InResponseTo or RelayState to retrieve the authentication request, then there's no way to know if the SAML response was requested by that handshake. +===== + +To help with this, Spring Security has `CacheSaml2AuthenticationRequestRepository`, which you can publish as a bean for the filter chain to pick up: + +[tabs] +====== +Java:: ++ +[source,java,role="primary"] +---- +@Bean +Saml2AuthenticationRequestRepository authenticationRequestRepository() { + return new CacheSaml2AuthenticationRequestRepository(); +} +---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Bean +fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<*> { + return CacheSaml2AuthenticationRequestRepository() +} +---- +====== + + [[servlet-saml2login-sp-initiated-factory-signing]] == Changing How the `` Gets Sent diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepository.java new file mode 100644 index 0000000000..d5813548bd --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepository.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.cache.Cache; +import org.springframework.cache.concurrent.ConcurrentMapCache; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.util.Assert; + +/** + * A cache-based {@link Saml2AuthenticationRequestRepository}. This can be handy when you + * are dropping requests due to using SameSite=Strict and the previous session is lost. + * + *

+ * On the other hand, this presents a tradeoff where the application can only tell that + * the given authentication request was created by this application, but cannot guarantee + * that it was for the user trying to log in. Please see the reference for details. + * + * @author Josh Cummings + * @since 6.5 + */ +public final class CacheSaml2AuthenticationRequestRepository + implements Saml2AuthenticationRequestRepository { + + private Cache cache = new ConcurrentMapCache("authentication-requests"); + + @Override + public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE); + Assert.notNull(relayState, "relayState must not be null"); + return this.cache.get(relayState, AbstractSaml2AuthenticationRequest.class); + } + + @Override + public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest, + HttpServletRequest request, HttpServletResponse response) { + String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE); + Assert.notNull(relayState, "relayState must not be null"); + this.cache.put(relayState, authenticationRequest); + } + + @Override + public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request, + HttpServletResponse response) { + String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE); + Assert.notNull(relayState, "relayState must not be null"); + AbstractSaml2AuthenticationRequest authenticationRequest = this.cache.get(relayState, + AbstractSaml2AuthenticationRequest.class); + if (authenticationRequest == null) { + return null; + } + this.cache.evict(relayState); + return authenticationRequest; + } + + /** + * Use this {@link Cache} instance. The default is an in-memory cache, which means it + * won't work in a clustered environment. Instead, replace it here with a distributed + * cache. + * @param cache the {@link Cache} instance to use + */ + public void setCache(Cache cache) { + this.cache = cache; + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepositoryTests.java new file mode 100644 index 0000000000..3bd35652f4 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepositoryTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.junit.jupiter.api.Test; + +import org.springframework.cache.Cache; +import org.springframework.cache.concurrent.ConcurrentMapCache; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2PostAuthenticationRequests; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link CacheSaml2AuthenticationRequestRepository} + */ +class CacheSaml2AuthenticationRequestRepositoryTests { + + CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository(); + + @Test + void loadAuthenticationRequestWhenCachedThenReturns() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.RELAY_STATE, "test"); + Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create(); + this.repository.saveAuthenticationRequest(authenticationRequest, request, null); + assertThat(this.repository.loadAuthenticationRequest(request)).isEqualTo(authenticationRequest); + this.repository.removeAuthenticationRequest(request, null); + assertThat(this.repository.loadAuthenticationRequest(request)).isNull(); + } + + @Test + void loadAuthenticationRequestWhenNoRelayStateThenException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.repository.loadAuthenticationRequest(request)); + } + + @Test + void saveAuthenticationRequestWhenNoRelayStateThenException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.repository.saveAuthenticationRequest(null, request, null)); + } + + @Test + void removeAuthenticationRequestWhenNoRelayStateThenException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.repository.removeAuthenticationRequest(request, null)); + } + + @Test + void repositoryWhenCustomCacheThenUses() { + CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository(); + Cache cache = spy(new ConcurrentMapCache("requests")); + repository.setCache(cache); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.RELAY_STATE, "test"); + Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create(); + repository.saveAuthenticationRequest(authenticationRequest, request, null); + verify(cache).put(eq("test"), any()); + repository.loadAuthenticationRequest(request); + verify(cache).get("test", AbstractSaml2AuthenticationRequest.class); + repository.removeAuthenticationRequest(request, null); + verify(cache).evict("test"); + } + +}