diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index eb56460ecd..a4cfb815dc 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -32,8 +32,6 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; -import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; -import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; @@ -55,6 +53,7 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** @@ -112,6 +111,8 @@ import org.springframework.util.StringUtils; public final class Saml2LoginConfigurer> extends AbstractAuthenticationFilterConfigurer, Saml2WebSsoAuthenticationFilter> { + private static final String OPEN_SAML_4_VERSION = "4"; + private String loginPage; private String authenticationRequestUri = "/saml2/authenticate/{registrationId}"; @@ -320,11 +321,9 @@ public final class Saml2LoginConfigurer> return resolver; } if (version().startsWith("4")) { - return new OpenSaml4AuthenticationRequestFactory(); - } - else { - return new OpenSamlAuthenticationRequestFactory(); + return OpenSaml4LoginSupportFactory.getAuthenticationRequestFactory(); } + return new OpenSamlAuthenticationRequestFactory(); } private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) { @@ -354,18 +353,9 @@ public final class Saml2LoginConfigurer> return authenticationConverterBean; } - private String version() { - String version = Version.getVersion(); - if (version != null) { - return version; - } - return Version.class.getModule().getDescriptor().version().map(Object::toString) - .orElseThrow(() -> new IllegalStateException("cannot determine OpenSAML version")); - } - private void registerDefaultAuthenticationProvider(B http) { if (version().startsWith("4")) { - http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + http.authenticationProvider(postProcess(OpenSaml4LoginSupportFactory.getAuthenticationProvider())); } else { http.authenticationProvider(postProcess(new OpenSamlAuthenticationProvider())); @@ -415,6 +405,19 @@ public final class Saml2LoginConfigurer> return repository; } + private String version() { + String version = Version.getVersion(); + if (StringUtils.hasText(version)) { + return version; + } + boolean openSaml4ClassPresent = ClassUtils + .isPresent("org.opensaml.core.xml.persist.impl.PassthroughSourceStrategy", null); + if (openSaml4ClassPresent) { + return OPEN_SAML_4_VERSION; + } + throw new IllegalStateException("cannot determine OpenSAML version"); + } + private C getSharedOrBean(B http, Class clazz) { C shared = http.getSharedObject(clazz); if (shared != null) { @@ -442,4 +445,33 @@ public final class Saml2LoginConfigurer> } } + private static class OpenSaml4LoginSupportFactory { + + private static Saml2AuthenticationRequestFactory getAuthenticationRequestFactory() { + try { + Class authenticationRequestFactory = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory", + OpenSaml4LoginSupportFactory.class.getClassLoader()); + return (Saml2AuthenticationRequestFactory) authenticationRequestFactory.getDeclaredConstructor() + .newInstance(); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4AuthenticationRequestFactory", ex); + } + } + + private static AuthenticationProvider getAuthenticationProvider() { + try { + Class authenticationProvider = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider", + OpenSaml4LoginSupportFactory.class.getClassLoader()); + return (AuthenticationProvider) authenticationProvider.getDeclaredConstructor().newInstance(); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4AuthenticationProvider", ex); + } + } + + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java index 45bd549c01..a5251bc94b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,8 +47,6 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegis import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml3LogoutRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml3LogoutResponseResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver; @@ -67,6 +65,8 @@ import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; /** * Adds SAML 2.0 logout support. @@ -113,6 +113,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; public final class Saml2LogoutConfigurer> extends AbstractHttpConfigurer, H> { + private static final String OPEN_SAML_4_VERSION = "4"; + private ApplicationContext context; private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; @@ -304,6 +306,19 @@ public final class Saml2LogoutConfigurer> return this.logoutResponseConfigurer.logoutResponseResolver(relyingPartyRegistrationResolver); } + private String version() { + String version = Version.getVersion(); + if (StringUtils.hasText(version)) { + return version; + } + boolean openSaml4ClassPresent = ClassUtils + .isPresent("org.opensaml.core.xml.persist.impl.PassthroughSourceStrategy", null); + if (openSaml4ClassPresent) { + return OPEN_SAML_4_VERSION; + } + throw new IllegalStateException("cannot determine OpenSAML version"); + } + private C getBeanOrNull(Class clazz) { if (this.context == null) { return null; @@ -314,15 +329,6 @@ public final class Saml2LogoutConfigurer> return this.context.getBean(clazz); } - private String version() { - String version = Version.getVersion(); - if (version != null) { - return version; - } - return Version.class.getModule().getDescriptor().version().map(Object::toString) - .orElseThrow(() -> new IllegalStateException("cannot determine OpenSAML version")); - } - /** * A configurer for SAML 2.0 LogoutRequest components */ @@ -403,7 +409,7 @@ public final class Saml2LogoutConfigurer> return this.logoutRequestResolver; } if (version().startsWith("4")) { - return new OpenSaml4LogoutRequestResolver(relyingPartyRegistrationResolver); + return OpenSaml4LogoutSupportFactory.getLogoutRequestResolver(relyingPartyRegistrationResolver); } return new OpenSaml3LogoutRequestResolver(relyingPartyRegistrationResolver); } @@ -471,13 +477,13 @@ public final class Saml2LogoutConfigurer> private Saml2LogoutResponseResolver logoutResponseResolver( RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - if (this.logoutResponseResolver == null) { - if (version().startsWith("4")) { - return new OpenSaml4LogoutResponseResolver(relyingPartyRegistrationResolver); - } - return new OpenSaml3LogoutResponseResolver(relyingPartyRegistrationResolver); + if (this.logoutResponseResolver != null) { + return this.logoutResponseResolver; } - return this.logoutResponseResolver; + if (version().startsWith("4")) { + return OpenSaml4LogoutSupportFactory.getLogoutResponseResolver(relyingPartyRegistrationResolver); + } + return new OpenSaml3LogoutResponseResolver(relyingPartyRegistrationResolver); } } @@ -520,4 +526,38 @@ public final class Saml2LogoutConfigurer> } + private static class OpenSaml4LogoutSupportFactory { + + private static Saml2LogoutResponseResolver getLogoutResponseResolver( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + try { + Class logoutResponseResolver = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver", + OpenSaml4LogoutSupportFactory.class.getClassLoader()); + return (Saml2LogoutResponseResolver) logoutResponseResolver + .getDeclaredConstructor(RelyingPartyRegistrationResolver.class) + .newInstance(relyingPartyRegistrationResolver); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4LogoutResponseResolver", ex); + } + } + + private static Saml2LogoutRequestResolver getLogoutRequestResolver( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + try { + Class logoutRequestResolver = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver", + OpenSaml4LogoutSupportFactory.class.getClassLoader()); + return (Saml2LogoutRequestResolver) logoutRequestResolver + .getDeclaredConstructor(RelyingPartyRegistrationResolver.class) + .newInstance(relyingPartyRegistrationResolver); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4LogoutRequestResolver", ex); + } + } + + } + }