From 3e87ef84aeefef5c0b14ba982ff5d8a139772b38 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 7 Sep 2021 12:13:49 -0600 Subject: [PATCH] Replace SecurityContextHolder#addListener Closes gh-10226 --- .../GlobalSecurityContextHolderStrategy.java | 4 - ...eadLocalSecurityContextHolderStrategy.java | 4 - ...isteningSecurityContextHolderStrategy.java | 131 +++++++++++++----- .../core/context/SecurityContextHolder.java | 118 +++++++++------- ...eadLocalSecurityContextHolderStrategy.java | 4 - ...ingSecurityContextHolderStrategyTests.java | 70 ++++++++++ .../context/SecurityContextHolderTests.java | 23 ++- 7 files changed, 243 insertions(+), 111 deletions(-) create mode 100644 core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java diff --git a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java index 0aaf696f0d..d8367c4ebd 100644 --- a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java @@ -31,10 +31,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder private static SecurityContext contextHolder; - SecurityContext peek() { - return contextHolder; - } - @Override public void clearContext() { contextHolder = null; diff --git a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java index 7ce665c2a1..cb415500ca 100644 --- a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java @@ -29,10 +29,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur private static final ThreadLocal contextHolder = new InheritableThreadLocal<>(); - SecurityContext peek() { - return contextHolder.get(); - } - @Override public void clearContext() { contextHolder.remove(); diff --git a/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java index 24c0fbd17e..3c5f763fef 100644 --- a/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java @@ -16,73 +16,130 @@ package org.springframework.security.core.context; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.BiConsumer; -import java.util.function.Supplier; +import java.util.Arrays; +import java.util.Collection; -final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy { +import org.springframework.util.Assert; - private static final BiConsumer NULL_PUBLISHER = (previous, current) -> { - }; +/** + * An API for notifying when the {@link SecurityContext} changes. + * + * Note that this does not notify when the underlying authentication changes. To get + * notified about authentication changes, ensure that you are using {@link #setContext} + * when changing the authentication like so: + * + *
+ *	SecurityContext context = SecurityContextHolder.createEmptyContext();
+ *	context.setAuthentication(authentication);
+ *	SecurityContextHolder.setContext(context);
+ * 
+ * + * To add a listener to the existing {@link SecurityContextHolder}, you can do: + * + *
+ *  SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
+ *  SecurityContextChangedListener listener = new YourListener();
+ *  SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(original, listener);
+ *  SecurityContextHolder.setContextHolderStrategy(strategy);
+ * 
+ * + * NOTE: Any object that you supply to the {@link SecurityContextHolder} is now part of + * the static context and as such will not get garbage collected. To remove the reference, + * {@link SecurityContextHolder#setContextHolderStrategy reset the strategy} like so: + * + *
+ *   SecurityContextHolder.setContextHolderStrategy(original);
+ * 
+ * + * This will then allow {@code YourListener} and its members to be garbage collected. + * + * @author Josh Cummings + * @since 5.6 + */ +public final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy { - private final Supplier peek; + private final Collection listeners; private final SecurityContextHolderStrategy delegate; - private final SecurityContextEventPublisher base = new SecurityContextEventPublisher(); - - private BiConsumer publisher = NULL_PUBLISHER; - - ListeningSecurityContextHolderStrategy(Supplier peek, SecurityContextHolderStrategy delegate) { - this.peek = peek; + /** + * Construct a {@link ListeningSecurityContextHolderStrategy} + * @param listeners the listeners that should be notified when the + * {@link SecurityContext} is {@link #setContext(SecurityContext) set} or + * {@link #clearContext() cleared} + * @param delegate the underlying {@link SecurityContextHolderStrategy} + */ + public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate, + Collection listeners) { + Assert.notNull(delegate, "securityContextHolderStrategy cannot be null"); + Assert.notNull(listeners, "securityContextChangedListeners cannot be null"); + Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty"); + Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements"); this.delegate = delegate; + this.listeners = listeners; } + /** + * Construct a {@link ListeningSecurityContextHolderStrategy} + * @param listeners the listeners that should be notified when the + * {@link SecurityContext} is {@link #setContext(SecurityContext) set} or + * {@link #clearContext() cleared} + * @param delegate the underlying {@link SecurityContextHolderStrategy} + */ + public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate, + SecurityContextChangedListener... listeners) { + Assert.notNull(delegate, "securityContextHolderStrategy cannot be null"); + Assert.notNull(listeners, "securityContextChangedListeners cannot be null"); + Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty"); + Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements"); + this.delegate = delegate; + this.listeners = Arrays.asList(listeners); + } + + /** + * {@inheritDoc} + */ @Override public void clearContext() { - SecurityContext from = this.peek.get(); + SecurityContext from = getContext(); this.delegate.clearContext(); - this.publisher.accept(from, null); + publish(from, null); } + /** + * {@inheritDoc} + */ @Override public SecurityContext getContext() { return this.delegate.getContext(); } + /** + * {@inheritDoc} + */ @Override public void setContext(SecurityContext context) { - SecurityContext from = this.peek.get(); + SecurityContext from = getContext(); this.delegate.setContext(context); - this.publisher.accept(from, context); + publish(from, context); } + /** + * {@inheritDoc} + */ @Override public SecurityContext createEmptyContext() { return this.delegate.createEmptyContext(); } - void addListener(SecurityContextChangedListener listener) { - this.base.listeners.add(listener); - this.publisher = this.base; - } - - private static class SecurityContextEventPublisher implements BiConsumer { - - private final List listeners = new CopyOnWriteArrayList<>(); - - @Override - public void accept(SecurityContext previous, SecurityContext current) { - if (previous == current) { - return; - } - SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current); - for (SecurityContextChangedListener listener : this.listeners) { - listener.securityContextChanged(event); - } + private void publish(SecurityContext previous, SecurityContext current) { + if (previous == current) { + return; + } + SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current); + for (SecurityContextChangedListener listener : this.listeners) { + listener.securityContextChanged(event); } - } } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java index cfce45ad25..337fde3a57 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java @@ -56,6 +56,8 @@ public class SecurityContextHolder { public static final String MODE_GLOBAL = "MODE_GLOBAL"; + private static final String MODE_PRE_INITIALIZED = "MODE_PRE_INITIALIZED"; + public static final String SYSTEM_PROPERTY = "spring.security.strategy"; private static String strategyName = System.getProperty(SYSTEM_PROPERTY); @@ -69,34 +71,41 @@ public class SecurityContextHolder { } private static void initialize() { + initializeStrategy(); + initializeCount++; + } + + private static void initializeStrategy() { + if (MODE_PRE_INITIALIZED.equals(strategyName)) { + Assert.state(strategy != null, "When using " + MODE_PRE_INITIALIZED + + ", setContextHolderStrategy must be called with the fully constructed strategy"); + return; + } if (!StringUtils.hasText(strategyName)) { // Set default strategyName = MODE_THREADLOCAL; } if (strategyName.equals(MODE_THREADLOCAL)) { - ThreadLocalSecurityContextHolderStrategy delegate = new ThreadLocalSecurityContextHolderStrategy(); - strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate); + strategy = new ThreadLocalSecurityContextHolderStrategy(); + return; } - else if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) { - InheritableThreadLocalSecurityContextHolderStrategy delegate = new InheritableThreadLocalSecurityContextHolderStrategy(); - strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate); + if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) { + strategy = new InheritableThreadLocalSecurityContextHolderStrategy(); + return; } - else if (strategyName.equals(MODE_GLOBAL)) { - GlobalSecurityContextHolderStrategy delegate = new GlobalSecurityContextHolderStrategy(); - strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate); + if (strategyName.equals(MODE_GLOBAL)) { + strategy = new GlobalSecurityContextHolderStrategy(); + return; } - else { - // Try to load a custom strategy - try { - Class clazz = Class.forName(strategyName); - Constructor customStrategy = clazz.getConstructor(); - strategy = (SecurityContextHolderStrategy) customStrategy.newInstance(); - } - catch (Exception ex) { - ReflectionUtils.handleReflectionException(ex); - } + // Try to load a custom strategy + try { + Class clazz = Class.forName(strategyName); + Constructor customStrategy = clazz.getConstructor(); + strategy = (SecurityContextHolderStrategy) customStrategy.newInstance(); + } + catch (Exception ex) { + ReflectionUtils.handleReflectionException(ex); } - initializeCount++; } /** @@ -118,7 +127,9 @@ public class SecurityContextHolder { * Primarily for troubleshooting purposes, this method shows how many times the class * has re-initialized its SecurityContextHolderStrategy. * @return the count (should be one unless you've called - * {@link #setStrategyName(String)} to switch to an alternate strategy. + * {@link #setStrategyName(String)} or + * {@link #setContextHolderStrategy(SecurityContextHolderStrategy)} to switch to an + * alternate strategy). */ public static int getInitializeCount() { return initializeCount; @@ -144,6 +155,41 @@ public class SecurityContextHolder { initialize(); } + /** + * Use this {@link SecurityContextHolderStrategy}. + * + * Call either {@link #setStrategyName(String)} or this method, but not both. + * + * This method is not thread safe. Changing the strategy while requests are in-flight + * may cause race conditions. + * + * {@link SecurityContextHolder} maintains a static reference to the provided + * {@link SecurityContextHolderStrategy}. This means that the strategy and its members + * will not be garbage collected until you remove your strategy. + * + * To ensure garbage collection, remember the original strategy like so: + * + *
+	 *     SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
+	 *     SecurityContextHolder.setContextHolderStrategy(myStrategy);
+	 * 
+ * + * And then when you are ready for {@code myStrategy} to be garbage collected you can + * do: + * + *
+	 *     SecurityContextHolder.setContextHolderStrategy(original);
+	 * 
+ * @param strategy the {@link SecurityContextHolderStrategy} to use + * @since 5.6 + */ + public static void setContextHolderStrategy(SecurityContextHolderStrategy strategy) { + Assert.notNull(strategy, "securityContextHolderStrategy cannot be null"); + SecurityContextHolder.strategyName = MODE_PRE_INITIALIZED; + SecurityContextHolder.strategy = strategy; + initialize(); + } + /** * Allows retrieval of the context strategy. See SEC-1188. * @return the configured strategy for storing the security context. @@ -159,38 +205,10 @@ public class SecurityContextHolder { return strategy.createEmptyContext(); } - /** - * Register a listener to be notified when the {@link SecurityContext} changes. - * - * Note that this does not notify when the underlying authentication changes. To get - * notified about authentication changes, ensure that you are using - * {@link #setContext} when changing the authentication like so: - * - *
-	 *	SecurityContext context = SecurityContextHolder.createEmptyContext();
-	 *	context.setAuthentication(authentication);
-	 *	SecurityContextHolder.setContext(context);
-	 * 
- * - * To integrate this with Spring's - * {@link org.springframework.context.ApplicationEvent} support, you can add a - * listener like so: - * - *
-	 *	SecurityContextHolder.addListener(this.applicationContext::publishEvent);
-	 * 
- * @param listener a listener to be notified when the {@link SecurityContext} changes - * @since 5.6 - */ - public static void addListener(SecurityContextChangedListener listener) { - Assert.isInstanceOf(ListeningSecurityContextHolderStrategy.class, strategy, - "strategy must be of type ListeningSecurityContextHolderStrategy to add listeners"); - ((ListeningSecurityContextHolderStrategy) strategy).addListener(listener); - } - @Override public String toString() { - return "SecurityContextHolder[strategy='" + strategyName + "'; initializeCount=" + initializeCount + "]"; + return "SecurityContextHolder[strategy='" + strategy.getClass().getSimpleName() + "'; initializeCount=" + + initializeCount + "]"; } } diff --git a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java index a3094bfa70..801f5c8207 100644 --- a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java @@ -30,10 +30,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH private static final ThreadLocal contextHolder = new ThreadLocal<>(); - SecurityContext peek() { - return contextHolder.get(); - } - @Override public void clearContext() { contextHolder.remove(); diff --git a/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java b/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java new file mode 100644 index 0000000000..999fee6fd6 --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +public class ListeningSecurityContextHolderStrategyTests { + + @Test + public void setContextWhenInvokedThenListenersAreNotified() { + SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class); + SecurityContextChangedListener one = mock(SecurityContextChangedListener.class); + SecurityContextChangedListener two = mock(SecurityContextChangedListener.class); + SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, one, two); + given(delegate.createEmptyContext()).willReturn(new SecurityContextImpl()); + SecurityContext context = strategy.createEmptyContext(); + strategy.setContext(context); + verify(delegate).setContext(context); + verify(one).securityContextChanged(any()); + verify(two).securityContextChanged(any()); + } + + @Test + public void setContextWhenNoChangeToContextThenListenersAreNotNotified() { + SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class); + SecurityContextChangedListener listener = mock(SecurityContextChangedListener.class); + SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, listener); + SecurityContext context = new SecurityContextImpl(); + given(delegate.getContext()).willReturn(context); + strategy.setContext(strategy.getContext()); + verify(delegate).setContext(context); + verifyNoInteractions(listener); + } + + @Test + public void constructorWhenNullDelegateThenIllegalArgument() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> new ListeningSecurityContextHolderStrategy(null, (event) -> { + })); + } + + @Test + public void constructorWhenNullListenerThenIllegalArgument() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy( + () -> new ListeningSecurityContextHolderStrategy(new ThreadLocalSecurityContextHolderStrategy(), + (SecurityContextChangedListener) null)); + } + +} diff --git a/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java index d1eaf70909..563f7a307a 100644 --- a/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java +++ b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java @@ -23,9 +23,7 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; /** @@ -63,16 +61,17 @@ public class SecurityContextHolderTests { } @Test - public void addListenerWhenInvokedThenListenersAreNotified() { - SecurityContextChangedListener one = mock(SecurityContextChangedListener.class); - SecurityContextChangedListener two = mock(SecurityContextChangedListener.class); - SecurityContextHolder.addListener(one); - SecurityContextHolder.addListener(two); - SecurityContext context = SecurityContextHolder.createEmptyContext(); - SecurityContextHolder.setContext(context); - SecurityContextHolder.clearContext(); - verify(one, times(2)).securityContextChanged(any(SecurityContextChangedEvent.class)); - verify(two, times(2)).securityContextChanged(any(SecurityContextChangedEvent.class)); + public void setContextHolderStrategyWhenCalledThenUsed() { + SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy(); + try { + SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class); + SecurityContextHolder.setContextHolderStrategy(delegate); + SecurityContextHolder.getContext(); + verify(delegate).getContext(); + } + finally { + SecurityContextHolder.setContextHolderStrategy(original); + } } }