diff --git a/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java index 3bd3ed8b8a..487ef70b2b 100644 --- a/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java @@ -46,6 +46,8 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity private String springSecurityContextAttrName = DEFAULT_SPRING_SECURITY_CONTEXT_ATTR_NAME; + private boolean cacheSecurityContext; + /** * Sets the session attribute name used to save and load the {@link SecurityContext} * @param springSecurityContextAttrName the session attribute name to use to save and @@ -56,6 +58,16 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity this.springSecurityContextAttrName = springSecurityContextAttrName; } + /** + * If set to true the result of {@link #load(ServerWebExchange)} will use + * {@link Mono#cache()} to prevent multiple lookups. + * @param cacheSecurityContext true if {@link Mono#cache()} should be used, else + * false. + */ + public void setCacheSecurityContext(boolean cacheSecurityContext) { + this.cacheSecurityContext = cacheSecurityContext; + } + @Override public Mono save(ServerWebExchange exchange, SecurityContext context) { return exchange.getSession().doOnNext((session) -> { @@ -72,13 +84,14 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity @Override public Mono load(ServerWebExchange exchange) { - return exchange.getSession().flatMap((session) -> { + Mono result = exchange.getSession().flatMap((session) -> { SecurityContext context = (SecurityContext) session.getAttribute(this.springSecurityContextAttrName); logger.debug((context != null) ? LogMessage.format("Found SecurityContext '%s' in WebSession: '%s'", context, session) : LogMessage.format("No SecurityContext found in WebSession: '%s'", session)); return Mono.justOrEmpty(context); }); + return (cacheSecurityContext) ? result.cache() : result; } } diff --git a/web/src/test/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepositoryTests.java b/web/src/test/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepositoryTests.java index f4af6f74f2..aa372e69fb 100644 --- a/web/src/test/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepositoryTests.java @@ -17,14 +17,19 @@ package org.springframework.security.web.server.context; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.publisher.PublisherProbe; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; /** * @author Rob Winch @@ -79,4 +84,25 @@ public class WebSessionServerSecurityContextRepositoryTests { assertThat(context).isNull(); } + @Test + public void loadWhenCacheSecurityContextThenSubscribeOnce() { + PublisherProbe webSession = PublisherProbe.empty(); + ServerWebExchange exchange = mock(ServerWebExchange.class); + given(exchange.getSession()).willReturn(webSession.mono()); + this.repository.setCacheSecurityContext(true); + Mono context = this.repository.load(exchange); + assertThat(context.block()).isSameAs(context.block()); + assertThat(webSession.subscribeCount()).isEqualTo(1); + } + + @Test + public void loadWhenNotCacheSecurityContextThenSubscribeMultiple() { + PublisherProbe webSession = PublisherProbe.empty(); + ServerWebExchange exchange = mock(ServerWebExchange.class); + given(exchange.getSession()).willReturn(webSession.mono()); + Mono context = this.repository.load(exchange); + assertThat(context.block()).isSameAs(context.block()); + assertThat(webSession.subscribeCount()).isEqualTo(2); + } + }