diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index 6987f42464..8acc64bf97 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -16,8 +16,6 @@ package org.springframework.security.saml2.provider.service.web; -import java.util.function.Function; - import jakarta.servlet.http.HttpServletRequest; import org.springframework.http.HttpMethod; @@ -43,7 +41,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; - private Function loader; + private Saml2AuthenticationRequestRepository authenticationRequestRepository; /** * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for @@ -54,12 +52,13 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; - this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; + this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository(); } @Override public Saml2AuthenticationToken convert(HttpServletRequest request) { - AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); + AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository + .loadAuthenticationRequest(request); String relyingPartyRegistrationId = (authenticationRequest != null) ? authenticationRequest.getRelyingPartyRegistrationId() : null; RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, @@ -84,11 +83,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo public void setAuthenticationRequestRepository( Saml2AuthenticationRequestRepository authenticationRequestRepository) { Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); - this.loader = authenticationRequestRepository::loadAuthenticationRequest; - } - - private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { - return this.loader.apply(request); + this.authenticationRequestRepository = authenticationRequestRepository; } private String decode(HttpServletRequest request) { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilter.java index 2e85ce082e..6c9ed2dc13 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilter.java @@ -29,7 +29,6 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; -import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; @@ -77,9 +76,7 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, String filterProcessesUrl) { this(new Saml2AuthenticationTokenConverter( - (RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver( - relyingPartyRegistrationRepository)), - filterProcessesUrl); + new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl); Assert.isTrue(filterProcessesUrl.contains("{registrationId}"), "filterProcessesUrl must contain a {registrationId} match variable"); } @@ -159,9 +156,9 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce } private void setDetails(HttpServletRequest request, Authentication authentication) { - if (AbstractAuthenticationToken.class.isAssignableFrom(authentication.getClass())) { + if (authentication instanceof AbstractAuthenticationToken token) { Object details = this.authenticationDetailsSource.buildDetails(request); - ((AbstractAuthenticationToken) authentication).setDetails(details); + token.setDetails(details); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java index 98cf1765df..3ef96f8481 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.web.authentication; +import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -94,16 +95,18 @@ public class Saml2WebSsoAuthenticationFilterTests { @Test public void requiresAuthenticationWhenHappyPathThenReturnsTrue() { - assertThat(this.filter.requiresAuthentication(this.request, this.response)).isTrue(); + RequiresAuthenticationExposingFilter filter = new RequiresAuthenticationExposingFilter(this.repository); + assertThat(filter.requiresAuthentication(this.request, this.response)).isTrue(); } @Test public void requiresAuthenticationWhenCustomProcessingUrlThenReturnsTrue() { - this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}"); + RequiresAuthenticationExposingFilter filter = new RequiresAuthenticationExposingFilter(this.repository, + "/some/other/path/{registrationId}"); this.request.setRequestURI("/some/other/path/idp-registration-id"); this.request.setPathInfo("/some/other/path/idp-registration-id"); this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "xml-data-goes-here"); - assertThat(this.filter.requiresAuthentication(this.request, this.response)).isTrue(); + assertThat(filter.requiresAuthentication(this.request, this.response)).isTrue(); } @Test @@ -212,4 +215,21 @@ public class Saml2WebSsoAuthenticationFilterTests { verify(this.repository).findByRegistrationId("registration-id"); } + static final class RequiresAuthenticationExposingFilter extends Saml2WebSsoAuthenticationFilter { + + RequiresAuthenticationExposingFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + super(relyingPartyRegistrationRepository); + } + + RequiresAuthenticationExposingFilter(RelyingPartyRegistrationRepository registrations, String url) { + super(registrations, url); + } + + @Override + protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) { + return super.requiresAuthentication(request, response); + } + + } + }