Add CacheSaml2AuthenticationRequestRepository

Closes gh-14793
This commit is contained in:
Josh Cummings 2025-04-03 17:10:11 -06:00
parent 8cbe02e3aa
commit a283700ef8
No known key found for this signature in database
GPG Key ID: 869B37A20E876129
3 changed files with 215 additions and 0 deletions

View File

@ -83,6 +83,46 @@ open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository
----
======
=== Caching the `<saml2:AuthnRequest>` by the Relay State
If you don't want to use the session to store the `<saml2:AuthnRequest>`, you can also store it in a distributed cache.
This can be helpful if you are trying to use `SameSite=Strict` and are losing the authentication request in the redirect from the Identity Provider.
[NOTE]
=====
It's important to remember that there are security benefits to storing it in the session.
One such benefit is the natural login fixation defense it provides.
For example, if an application looks the authentication request up from the session, then even if an attacker provides their own SAML response to a victim, the login will fail.
On the other hand, if we trust the InResponseTo or RelayState to retrieve the authentication request, then there's no way to know if the SAML response was requested by that handshake.
=====
To help with this, Spring Security has `CacheSaml2AuthenticationRequestRepository`, which you can publish as a bean for the filter chain to pick up:
[tabs]
======
Java::
+
[source,java,role="primary"]
----
@Bean
Saml2AuthenticationRequestRepository<?> authenticationRequestRepository() {
return new CacheSaml2AuthenticationRequestRepository();
}
----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Bean
fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<*> {
return CacheSaml2AuthenticationRequestRepository()
}
----
======
[[servlet-saml2login-sp-initiated-factory-signing]]
== Changing How the `<saml2:AuthnRequest>` Gets Sent

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.util.Assert;
/**
* A cache-based {@link Saml2AuthenticationRequestRepository}. This can be handy when you
* are dropping requests due to using SameSite=Strict and the previous session is lost.
*
* <p>
* On the other hand, this presents a tradeoff where the application can only tell that
* the given authentication request was created by this application, but cannot guarantee
* that it was for the user trying to log in. Please see the reference for details.
*
* @author Josh Cummings
* @since 6.5
*/
public final class CacheSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
private Cache cache = new ConcurrentMapCache("authentication-requests");
@Override
public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");
return this.cache.get(relayState, AbstractSaml2AuthenticationRequest.class);
}
@Override
public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
HttpServletRequest request, HttpServletResponse response) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");
this.cache.put(relayState, authenticationRequest);
}
@Override
public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
HttpServletResponse response) {
String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
Assert.notNull(relayState, "relayState must not be null");
AbstractSaml2AuthenticationRequest authenticationRequest = this.cache.get(relayState,
AbstractSaml2AuthenticationRequest.class);
if (authenticationRequest == null) {
return null;
}
this.cache.evict(relayState);
return authenticationRequest;
}
/**
* Use this {@link Cache} instance. The default is an in-memory cache, which means it
* won't work in a clustered environment. Instead, replace it here with a distributed
* cache.
* @param cache the {@link Cache} instance to use
*/
public void setCache(Cache cache) {
this.cache = cache;
}
}

View File

@ -0,0 +1,91 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web;
import org.junit.jupiter.api.Test;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.TestSaml2PostAuthenticationRequests;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link CacheSaml2AuthenticationRequestRepository}
*/
class CacheSaml2AuthenticationRequestRepositoryTests {
CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository();
@Test
void loadAuthenticationRequestWhenCachedThenReturns() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(Saml2ParameterNames.RELAY_STATE, "test");
Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create();
this.repository.saveAuthenticationRequest(authenticationRequest, request, null);
assertThat(this.repository.loadAuthenticationRequest(request)).isEqualTo(authenticationRequest);
this.repository.removeAuthenticationRequest(request, null);
assertThat(this.repository.loadAuthenticationRequest(request)).isNull();
}
@Test
void loadAuthenticationRequestWhenNoRelayStateThenException() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.repository.loadAuthenticationRequest(request));
}
@Test
void saveAuthenticationRequestWhenNoRelayStateThenException() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.repository.saveAuthenticationRequest(null, request, null));
}
@Test
void removeAuthenticationRequestWhenNoRelayStateThenException() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.repository.removeAuthenticationRequest(request, null));
}
@Test
void repositoryWhenCustomCacheThenUses() {
CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository();
Cache cache = spy(new ConcurrentMapCache("requests"));
repository.setCache(cache);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(Saml2ParameterNames.RELAY_STATE, "test");
Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create();
repository.saveAuthenticationRequest(authenticationRequest, request, null);
verify(cache).put(eq("test"), any());
repository.loadAuthenticationRequest(request);
verify(cache).get("test", AbstractSaml2AuthenticationRequest.class);
repository.removeAuthenticationRequest(request, null);
verify(cache).evict("test");
}
}