diff --git a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java index cb415500ca..e64d9058d1 100644 --- a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java @@ -16,6 +16,8 @@ package org.springframework.security.core.context; +import java.util.function.Supplier; + import org.springframework.util.Assert; /** @@ -23,11 +25,12 @@ import org.springframework.util.Assert; * {@link org.springframework.security.core.context.SecurityContextHolderStrategy}. * * @author Ben Alex + * @author Rob Winch * @see java.lang.ThreadLocal */ final class InheritableThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy { - private static final ThreadLocal contextHolder = new InheritableThreadLocal<>(); + private static final ThreadLocal> contextHolder = new InheritableThreadLocal<>(); @Override public void clearContext() { @@ -36,18 +39,35 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur @Override public SecurityContext getContext() { - SecurityContext ctx = contextHolder.get(); - if (ctx == null) { - ctx = createEmptyContext(); - contextHolder.set(ctx); + return getDeferredContext().get(); + } + + @Override + public Supplier getDeferredContext() { + Supplier result = contextHolder.get(); + if (result == null) { + SecurityContext context = createEmptyContext(); + result = () -> context; + contextHolder.set(result); } - return ctx; + return result; } @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); - contextHolder.set(context); + contextHolder.set(() -> context); + } + + @Override + public void setDeferredContext(Supplier deferredContext) { + Assert.notNull(deferredContext, "Only non-null Supplier instances are permitted"); + Supplier notNullDeferredContext = () -> { + SecurityContext result = deferredContext.get(); + Assert.notNull(result, "A Supplier returned null and is not allowed."); + return result; + }; + contextHolder.set(notNullDeferredContext); } @Override diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java index 337fde3a57..de593d1fc0 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java @@ -17,6 +17,7 @@ package org.springframework.security.core.context; import java.lang.reflect.Constructor; +import java.util.function.Supplier; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; @@ -46,6 +47,7 @@ import org.springframework.util.StringUtils; * {@link #MODE_GLOBAL} is definitely inappropriate for server use). * * @author Ben Alex + * @author Rob Winch * */ public class SecurityContextHolder { @@ -123,6 +125,16 @@ public class SecurityContextHolder { return strategy.getContext(); } + /** + * Obtains a {@link Supplier} that returns the current context. + * @return a {@link Supplier} that returns the current context (never + * null - create a default implementation if necessary) + * @since 5.8 + */ + public static Supplier getDeferredContext() { + return strategy.getDeferredContext(); + } + /** * Primarily for troubleshooting purposes, this method shows how many times the class * has re-initialized its SecurityContextHolderStrategy. @@ -143,6 +155,16 @@ public class SecurityContextHolder { strategy.setContext(context); } + /** + * Sets a {@link Supplier} that will return the current context. Implementations can + * override the default to avoid invoking {@link Supplier#get()}. + * @param deferredContext a {@link Supplier} that returns the {@link SecurityContext} + * @since 5.8 + */ + public static void setDeferredContext(Supplier deferredContext) { + strategy.setDeferredContext(deferredContext); + } + /** * Changes the preferred strategy. Do NOT call this method more than once for * a given JVM, as it will re-initialize the strategy and adversely affect any diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java index 4954db70aa..aaf7def3c2 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java @@ -16,6 +16,8 @@ package org.springframework.security.core.context; +import java.util.function.Supplier; + /** * A strategy for storing security context information against a thread. * @@ -23,6 +25,7 @@ package org.springframework.security.core.context; * The preferred strategy is loaded by {@link SecurityContextHolder}. * * @author Ben Alex + * @author Rob Winch */ public interface SecurityContextHolderStrategy { @@ -38,6 +41,16 @@ public interface SecurityContextHolderStrategy { */ SecurityContext getContext(); + /** + * Obtains a {@link Supplier} that returns the current context. + * @return a {@link Supplier} that returns the current context (never + * null - create a default implementation if necessary) + * @since 5.8 + */ + default Supplier getDeferredContext() { + return () -> getContext(); + } + /** * Sets the current context. * @param context to the new argument (should never be null, although @@ -46,6 +59,16 @@ public interface SecurityContextHolderStrategy { */ void setContext(SecurityContext context); + /** + * Sets a {@link Supplier} that will return the current context. Implementations can + * override the default to avoid invoking {@link Supplier#get()}. + * @param deferredContext a {@link Supplier} that returns the {@link SecurityContext} + * @since 5.8 + */ + default void setDeferredContext(Supplier deferredContext) { + setContext(deferredContext.get()); + } + /** * Creates a new, empty context implementation, for use by * SecurityContextRepository implementations, when creating a new context for diff --git a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java index 801f5c8207..f4fb5689c0 100644 --- a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java @@ -16,6 +16,8 @@ package org.springframework.security.core.context; +import java.util.function.Supplier; + import org.springframework.util.Assert; /** @@ -23,12 +25,13 @@ import org.springframework.util.Assert; * {@link SecurityContextHolderStrategy}. * * @author Ben Alex + * @author Rob Winch * @see java.lang.ThreadLocal * @see org.springframework.security.core.context.web.SecurityContextPersistenceFilter */ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy { - private static final ThreadLocal contextHolder = new ThreadLocal<>(); + private static final ThreadLocal> contextHolder = new ThreadLocal<>(); @Override public void clearContext() { @@ -37,18 +40,35 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH @Override public SecurityContext getContext() { - SecurityContext ctx = contextHolder.get(); - if (ctx == null) { - ctx = createEmptyContext(); - contextHolder.set(ctx); + return getDeferredContext().get(); + } + + @Override + public Supplier getDeferredContext() { + Supplier result = contextHolder.get(); + if (result == null) { + SecurityContext context = createEmptyContext(); + result = () -> context; + contextHolder.set(result); } - return ctx; + return result; } @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); - contextHolder.set(context); + contextHolder.set(() -> context); + } + + @Override + public void setDeferredContext(Supplier deferredContext) { + Assert.notNull(deferredContext, "Only non-null Supplier instances are permitted"); + Supplier notNullDeferredContext = () -> { + SecurityContext result = deferredContext.get(); + Assert.notNull(result, "A Supplier returned null and is not allowed."); + return result; + }; + contextHolder.set(notNullDeferredContext); } @Override diff --git a/core/src/test/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategyTests.java b/core/src/test/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategyTests.java new file mode 100644 index 0000000000..6fc01cca5d --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategyTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.core.Authentication; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +class InheritableThreadLocalSecurityContextHolderStrategyTests { + + InheritableThreadLocalSecurityContextHolderStrategy strategy = new InheritableThreadLocalSecurityContextHolderStrategy(); + + @AfterEach + void clearContext() { + this.strategy.clearContext(); + } + + @Test + void deferredNotInvoked() { + Supplier deferredContext = mock(Supplier.class); + this.strategy.setDeferredContext(deferredContext); + verifyNoInteractions(deferredContext); + } + + @Test + void deferredContext() { + Authentication authentication = mock(Authentication.class); + Supplier deferredContext = () -> new SecurityContextImpl(authentication); + this.strategy.setDeferredContext(deferredContext); + assertThat(this.strategy.getDeferredContext().get()).isEqualTo(deferredContext.get()); + assertThat(this.strategy.getContext()).isEqualTo(deferredContext.get()); + } + + @Test + void deferredContextValidates() { + this.strategy.setDeferredContext(() -> null); + Supplier deferredContext = this.strategy.getDeferredContext(); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> deferredContext.get()); + } + + @Test + void context() { + Authentication authentication = mock(Authentication.class); + SecurityContext context = new SecurityContextImpl(authentication); + this.strategy.setContext(context); + assertThat(this.strategy.getContext()).isEqualTo(context); + assertThat(this.strategy.getDeferredContext().get()).isEqualTo(context); + } + + @Test + void contextValidates() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> this.strategy.setContext(null)); + } + + @Test + void getContextWhenEmptyThenReturnsSameInstance() { + Authentication authentication = mock(Authentication.class); + this.strategy.getContext().setAuthentication(authentication); + assertThat(this.strategy.getContext().getAuthentication()).isEqualTo(authentication); + } + +} diff --git a/core/src/test/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategyTests.java b/core/src/test/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategyTests.java new file mode 100644 index 0000000000..9fc4b66401 --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategyTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.core.Authentication; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +class ThreadLocalSecurityContextHolderStrategyTests { + + ThreadLocalSecurityContextHolderStrategy strategy = new ThreadLocalSecurityContextHolderStrategy(); + + @AfterEach + void clearContext() { + this.strategy.clearContext(); + } + + @Test + void deferredNotInvoked() { + Supplier deferredContext = mock(Supplier.class); + this.strategy.setDeferredContext(deferredContext); + verifyNoInteractions(deferredContext); + } + + @Test + void deferredContext() { + Authentication authentication = mock(Authentication.class); + Supplier deferredContext = () -> new SecurityContextImpl(authentication); + this.strategy.setDeferredContext(deferredContext); + assertThat(this.strategy.getDeferredContext().get()).isEqualTo(deferredContext.get()); + assertThat(this.strategy.getContext()).isEqualTo(deferredContext.get()); + } + + @Test + void deferredContextValidates() { + this.strategy.setDeferredContext(() -> null); + Supplier deferredContext = this.strategy.getDeferredContext(); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> deferredContext.get()); + } + + @Test + void context() { + Authentication authentication = mock(Authentication.class); + SecurityContext context = new SecurityContextImpl(authentication); + this.strategy.setContext(context); + assertThat(this.strategy.getContext()).isEqualTo(context); + assertThat(this.strategy.getDeferredContext().get()).isEqualTo(context); + } + + @Test + void contextValidates() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> this.strategy.setContext(null)); + } + + @Test + void getContextWhenEmptyThenReturnsSameInstance() { + Authentication authentication = mock(Authentication.class); + this.strategy.getContext().setAuthentication(authentication); + assertThat(this.strategy.getContext().getAuthentication()).isEqualTo(authentication); + } + +} diff --git a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java index 1b0764601d..45ecdad2e3 100644 --- a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java @@ -17,6 +17,7 @@ package org.springframework.security.web.context; import java.io.IOException; +import java.util.function.Supplier; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -62,9 +63,9 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - SecurityContext securityContext = this.securityContextRepository.loadContext(request).get(); + Supplier deferredContext = this.securityContextRepository.loadContext(request); try { - this.securityContextHolderStrategy.setContext(securityContext); + this.securityContextHolderStrategy.setDeferredContext(deferredContext); filterChain.doFilter(request, response); } finally { diff --git a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java index b785d30df6..4983eb3166 100644 --- a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java @@ -16,6 +16,8 @@ package org.springframework.security.web.context; +import java.util.function.Supplier; + import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -94,7 +96,9 @@ class SecurityContextHolderFilterTests { this.filter.setSecurityContextHolderStrategy(this.strategy); this.filter.doFilter(this.request, this.response, filterChain); - verify(this.strategy).setContext(expectedContext); + ArgumentCaptor> deferredContextArg = ArgumentCaptor.forClass(Supplier.class); + verify(this.strategy).setDeferredContext(deferredContextArg.capture()); + assertThat(deferredContextArg.getValue().get()).isEqualTo(expectedContext); verify(this.strategy).clearContext(); }