Check for an existing SessionRegistry bean

If a SessionRegistry is necessary, check for one in the ApplicationContext before creating one.
This commit is contained in:
Craig Andrews 2020-05-14 15:02:44 -04:00 committed by Rob Winch
parent 0fa339f75b
commit dbdeec4216
2 changed files with 89 additions and 0 deletions

View File

@ -21,6 +21,7 @@ import java.util.List;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
import org.springframework.context.event.GenericApplicationListenerAdapter; import org.springframework.context.event.GenericApplicationListenerAdapter;
@ -678,6 +679,9 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
} }
private SessionRegistry getSessionRegistry(H http) { private SessionRegistry getSessionRegistry(H http) {
if (this.sessionRegistry == null) {
this.sessionRegistry = getBeanOrNull(SessionRegistry.class);
}
if (this.sessionRegistry == null) { if (this.sessionRegistry == null) {
SessionRegistryImpl sessionRegistry = new SessionRegistryImpl(); SessionRegistryImpl sessionRegistry = new SessionRegistryImpl();
registerDelegateApplicationListener(http, sessionRegistry); registerDelegateApplicationListener(http, sessionRegistry);
@ -717,4 +721,17 @@ public final class SessionManagementConfigurer<H extends HttpSecurityBuilder<H>>
private static SessionAuthenticationStrategy createDefaultSessionFixationProtectionStrategy() { private static SessionAuthenticationStrategy createDefaultSessionFixationProtectionStrategy() {
return new ChangeSessionIdAuthenticationStrategy(); return new ChangeSessionIdAuthenticationStrategy();
} }
private <T> T getBeanOrNull(Class<T> type) {
ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class);
if (context == null) {
return null;
}
try {
return context.getBean(type);
}
catch (NoSuchBeanDefinitionException e) {
return null;
}
}
} }

View File

@ -30,6 +30,7 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur
import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy; import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
import org.springframework.security.web.authentication.session.CompositeSessionAuthenticationStrategy; import org.springframework.security.web.authentication.session.CompositeSessionAuthenticationStrategy;
@ -53,6 +54,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@ -483,4 +485,74 @@ public class SessionManagementConfigurerTests {
// @formatter:on // @formatter:on
} }
} }
@Test
public void whenOneSessionRegistryBeanThenUseIt() throws Exception {
SessionRegistryOneBeanConfig.SESSION_REGISTRY = mock(SessionRegistry.class);
this.spring.register(SessionRegistryOneBeanConfig.class).autowire();
MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext());
this.mvc.perform(get("/").session(session));
verify(SessionRegistryOneBeanConfig.SESSION_REGISTRY)
.getSessionInformation(session.getId());
}
@Test
public void whenTwoSessionRegistryBeansThenUseNeither() throws Exception {
SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE = mock(SessionRegistry.class);
SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO = mock(SessionRegistry.class);
this.spring.register(SessionRegistryTwoBeansConfig.class).autowire();
MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext());
this.mvc.perform(get("/").session(session));
verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE);
verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO);
}
@EnableWebSecurity
static class SessionRegistryOneBeanConfig extends WebSecurityConfigurerAdapter {
private static SessionRegistry SESSION_REGISTRY;
@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.sessionManagement()
.maximumSessions(1);
// @formatter:on
}
@Bean
public SessionRegistry sessionRegistry() {
return SESSION_REGISTRY;
}
}
@EnableWebSecurity
static class SessionRegistryTwoBeansConfig extends WebSecurityConfigurerAdapter {
private static SessionRegistry SESSION_REGISTRY_ONE;
private static SessionRegistry SESSION_REGISTRY_TWO;
@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.sessionManagement()
.maximumSessions(1);
// @formatter:on
}
@Bean
public SessionRegistry sessionRegistryOne() {
return SESSION_REGISTRY_ONE;
}
@Bean
public SessionRegistry sessionRegistryTwo() {
return SESSION_REGISTRY_TWO;
}
}
} }