From dbdeec4216ae09f8014e31c6bbf3cf4c3149e0eb Mon Sep 17 00:00:00 2001 From: Craig Andrews Date: Thu, 14 May 2020 15:02:44 -0400 Subject: [PATCH] Check for an existing SessionRegistry bean If a SessionRegistry is necessary, check for one in the ApplicationContext before creating one. --- .../SessionManagementConfigurer.java | 17 +++++ .../SessionManagementConfigurerTests.java | 72 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java index e53be5d8bb..baf041e44a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java @@ -21,6 +21,7 @@ import java.util.List; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationListener; import org.springframework.context.event.GenericApplicationListenerAdapter; @@ -678,6 +679,9 @@ public final class SessionManagementConfigurer> } private SessionRegistry getSessionRegistry(H http) { + if (this.sessionRegistry == null) { + this.sessionRegistry = getBeanOrNull(SessionRegistry.class); + } if (this.sessionRegistry == null) { SessionRegistryImpl sessionRegistry = new SessionRegistryImpl(); registerDelegateApplicationListener(http, sessionRegistry); @@ -717,4 +721,17 @@ public final class SessionManagementConfigurer> private static SessionAuthenticationStrategy createDefaultSessionFixationProtectionStrategy() { return new ChangeSessionIdAuthenticationStrategy(); } + + private T getBeanOrNull(Class type) { + ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class); + if (context == null) { + return null; + } + try { + return context.getBean(type); + } + catch (NoSuchBeanDefinitionException e) { + return null; + } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java index a5e2cdf160..b1104f58d3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java @@ -30,6 +30,7 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy; import org.springframework.security.web.authentication.session.CompositeSessionAuthenticationStrategy; @@ -53,6 +54,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -483,4 +485,74 @@ public class SessionManagementConfigurerTests { // @formatter:on } } + + @Test + public void whenOneSessionRegistryBeanThenUseIt() throws Exception { + SessionRegistryOneBeanConfig.SESSION_REGISTRY = mock(SessionRegistry.class); + this.spring.register(SessionRegistryOneBeanConfig.class).autowire(); + + MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); + this.mvc.perform(get("/").session(session)); + + verify(SessionRegistryOneBeanConfig.SESSION_REGISTRY) + .getSessionInformation(session.getId()); + } + + @Test + public void whenTwoSessionRegistryBeansThenUseNeither() throws Exception { + SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE = mock(SessionRegistry.class); + SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO = mock(SessionRegistry.class); + this.spring.register(SessionRegistryTwoBeansConfig.class).autowire(); + + MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); + this.mvc.perform(get("/").session(session)); + + verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE); + verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO); + } + + @EnableWebSecurity + static class SessionRegistryOneBeanConfig extends WebSecurityConfigurerAdapter { + private static SessionRegistry SESSION_REGISTRY; + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .sessionManagement() + .maximumSessions(1); + // @formatter:on + } + + @Bean + public SessionRegistry sessionRegistry() { + return SESSION_REGISTRY; + } + } + + @EnableWebSecurity + static class SessionRegistryTwoBeansConfig extends WebSecurityConfigurerAdapter { + private static SessionRegistry SESSION_REGISTRY_ONE; + + private static SessionRegistry SESSION_REGISTRY_TWO; + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .sessionManagement() + .maximumSessions(1); + // @formatter:on + } + + @Bean + public SessionRegistry sessionRegistryOne() { + return SESSION_REGISTRY_ONE; + } + + @Bean + public SessionRegistry sessionRegistryTwo() { + return SESSION_REGISTRY_TWO; + } + } }