From 459003e1b32f7be41f1380cdb49270358772f359 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 21 Jun 2022 16:38:56 -0600 Subject: [PATCH] Use SecurityContextHolderStrategy for Context Propagation Issue gh-11060 --- ...tractDelegatingSecurityContextSupport.java | 19 +++++- .../DelegatingSecurityContextCallable.java | 61 +++++++++++++++---- .../DelegatingSecurityContextExecutor.java | 11 ++++ .../DelegatingSecurityContextRunnable.java | 61 +++++++++++++++---- ...tDelegatingSecurityContextTestSupport.java | 18 +++--- ...elegatingSecurityContextCallableTests.java | 25 +++++++- ...elegatingSecurityContextRunnableTests.java | 25 ++++++++ .../MockSecurityContextHolderStrategy.java | 43 +++++++++++++ 8 files changed, 227 insertions(+), 36 deletions(-) create mode 100644 core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java diff --git a/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java b/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java index e3e6e0267a..f500cfaa2e 100644 --- a/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java +++ b/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,9 @@ package org.springframework.security.concurrent; import java.util.concurrent.Callable; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.util.Assert; /** * An internal support class that wraps {@link Callable} with @@ -30,6 +33,9 @@ import org.springframework.security.core.context.SecurityContext; */ abstract class AbstractDelegatingSecurityContextSupport { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private final SecurityContext securityContext; /** @@ -44,12 +50,19 @@ abstract class AbstractDelegatingSecurityContextSupport { this.securityContext = securityContext; } + void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + protected final Runnable wrap(Runnable delegate) { - return DelegatingSecurityContextRunnable.create(delegate, this.securityContext); + return DelegatingSecurityContextRunnable.create(delegate, this.securityContext, + this.securityContextHolderStrategy); } protected final Callable wrap(Callable delegate) { - return DelegatingSecurityContextCallable.create(delegate, this.securityContext); + return DelegatingSecurityContextCallable.create(delegate, this.securityContext, + this.securityContextHolderStrategy); } } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java index 50d1b89af0..8842551912 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.concurrent.Callable; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -40,10 +41,15 @@ public final class DelegatingSecurityContextCallable implements Callable { private final Callable delegate; + private final boolean explicitSecurityContextProvided; + /** * The {@link SecurityContext} that the delegate {@link Callable} will be ran as. */ - private final SecurityContext delegateSecurityContext; + private SecurityContext delegateSecurityContext; + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); /** * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to @@ -60,10 +66,7 @@ public final class DelegatingSecurityContextCallable implements Callable { * {@link Callable}. Cannot be null. */ public DelegatingSecurityContextCallable(Callable delegate, SecurityContext securityContext) { - Assert.notNull(delegate, "delegate cannot be null"); - Assert.notNull(securityContext, "securityContext cannot be null"); - this.delegate = delegate; - this.delegateSecurityContext = securityContext; + this(delegate, securityContext, true); } /** @@ -73,28 +76,51 @@ public final class DelegatingSecurityContextCallable implements Callable { * {@link SecurityContext}. Cannot be null. */ public DelegatingSecurityContextCallable(Callable delegate) { - this(delegate, SecurityContextHolder.getContext()); + this(delegate, SecurityContextHolder.getContext(), false); + } + + private DelegatingSecurityContextCallable(Callable delegate, SecurityContext securityContext, + boolean explicitSecurityContextProvided) { + Assert.notNull(delegate, "delegate cannot be null"); + Assert.notNull(securityContext, "securityContext cannot be null"); + this.delegate = delegate; + this.delegateSecurityContext = securityContext; + this.explicitSecurityContextProvided = explicitSecurityContextProvided; } @Override public V call() throws Exception { - this.originalSecurityContext = SecurityContextHolder.getContext(); + this.originalSecurityContext = this.securityContextHolderStrategy.getContext(); try { - SecurityContextHolder.setContext(this.delegateSecurityContext); + this.securityContextHolderStrategy.setContext(this.delegateSecurityContext); return this.delegate.call(); } finally { - SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext(); if (emptyContext.equals(this.originalSecurityContext)) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); } else { - SecurityContextHolder.setContext(this.originalSecurityContext); + this.securityContextHolderStrategy.setContext(this.originalSecurityContext); } this.originalSecurityContext = null; } } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + if (!this.explicitSecurityContextProvided) { + this.delegateSecurityContext = securityContextHolderStrategy.getContext(); + } + } + @Override public String toString() { return this.delegate.toString(); @@ -116,4 +142,15 @@ public final class DelegatingSecurityContextCallable implements Callable { : new DelegatingSecurityContextCallable<>(delegate); } + static Callable create(Callable delegate, SecurityContext securityContext, + SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(delegate, "delegate cannot be null"); + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + DelegatingSecurityContextCallable callable = (securityContext != null) + ? new DelegatingSecurityContextCallable<>(delegate, securityContext) + : new DelegatingSecurityContextCallable<>(delegate); + callable.setSecurityContextHolderStrategy(securityContextHolderStrategy); + return callable; + } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java index c1af6a7546..f4d0545d23 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java @@ -20,6 +20,7 @@ import java.util.concurrent.Executor; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -66,4 +67,14 @@ public class DelegatingSecurityContextExecutor extends AbstractDelegatingSecurit return this.delegate; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + super.setSecurityContextHolderStrategy(securityContextHolderStrategy); + } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java index 24b0746641..98deb61cec 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.security.concurrent; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -38,10 +39,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable { private final Runnable delegate; + private final boolean explicitSecurityContextProvided; + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + /** * The {@link SecurityContext} that the delegate {@link Runnable} will be ran as. */ - private final SecurityContext delegateSecurityContext; + private SecurityContext delegateSecurityContext; /** * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to @@ -58,10 +64,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable { * {@link Runnable}. Cannot be null. */ public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) { - Assert.notNull(delegate, "delegate cannot be null"); - Assert.notNull(securityContext, "securityContext cannot be null"); - this.delegate = delegate; - this.delegateSecurityContext = securityContext; + this(delegate, securityContext, true); } /** @@ -71,28 +74,51 @@ public final class DelegatingSecurityContextRunnable implements Runnable { * {@link SecurityContext}. Cannot be null. */ public DelegatingSecurityContextRunnable(Runnable delegate) { - this(delegate, SecurityContextHolder.getContext()); + this(delegate, SecurityContextHolder.getContext(), false); + } + + private DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext, + boolean explicitSecurityContextProvided) { + Assert.notNull(delegate, "delegate cannot be null"); + Assert.notNull(securityContext, "securityContext cannot be null"); + this.delegate = delegate; + this.delegateSecurityContext = securityContext; + this.explicitSecurityContextProvided = explicitSecurityContextProvided; } @Override public void run() { - this.originalSecurityContext = SecurityContextHolder.getContext(); + this.originalSecurityContext = this.securityContextHolderStrategy.getContext(); try { - SecurityContextHolder.setContext(this.delegateSecurityContext); + this.securityContextHolderStrategy.setContext(this.delegateSecurityContext); this.delegate.run(); } finally { - SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext(); if (emptyContext.equals(this.originalSecurityContext)) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); } else { - SecurityContextHolder.setContext(this.originalSecurityContext); + this.securityContextHolderStrategy.setContext(this.originalSecurityContext); } this.originalSecurityContext = null; } } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + if (!this.explicitSecurityContextProvided) { + this.delegateSecurityContext = this.securityContextHolderStrategy.getContext(); + } + } + @Override public String toString() { return this.delegate.toString(); @@ -114,4 +140,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable { : new DelegatingSecurityContextRunnable(delegate); } + static Runnable create(Runnable delegate, SecurityContext securityContext, + SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(delegate, "delegate cannot be null"); + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + DelegatingSecurityContextRunnable runnable = (securityContext != null) + ? new DelegatingSecurityContextRunnable(delegate, securityContext) + : new DelegatingSecurityContextRunnable(delegate); + runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy); + return runnable; + } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java index 7ebd72bd1d..9012c26e5f 100644 --- a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java +++ b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java @@ -30,7 +30,9 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; /** * Abstract base class for testing classes that extend @@ -71,18 +73,18 @@ public abstract class AbstractDelegatingSecurityContextTestSupport { protected MockedStatic delegatingSecurityContextRunnable; public final void explicitSecurityContextSetup() throws Exception { - this.delegatingSecurityContextCallable.when( - () -> DelegatingSecurityContextCallable.create(eq(this.callable), this.securityContextCaptor.capture())) - .thenReturn(this.wrappedCallable); - this.delegatingSecurityContextRunnable.when( - () -> DelegatingSecurityContextRunnable.create(eq(this.runnable), this.securityContextCaptor.capture())) - .thenReturn(this.wrappedRunnable); + this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable), + this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedCallable); + this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), + this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedRunnable); } public final void currentSecurityContextSetup() throws Exception { - this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(this.callable, null)) + this.delegatingSecurityContextCallable + .when(() -> DelegatingSecurityContextCallable.create(eq(this.callable), isNull(), any())) .thenReturn(this.wrappedCallable); - this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(this.runnable, null)) + this.delegatingSecurityContextRunnable + .when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), isNull(), any())) .thenReturn(this.wrappedRunnable); } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java index c5110efd2a..08ceac3c7d 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java @@ -30,12 +30,16 @@ import org.mockito.internal.stubbing.answers.Returns; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.core.context.MockSecurityContextHolderStrategy; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; /** @@ -68,10 +72,15 @@ public class DelegatingSecurityContextCallableTests { } private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception { + givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolder.getContextHolderStrategy()); + } + + private void givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy) + throws Exception { given(this.delegate.call()).willAnswer(new Returns(this.callableResult) { @Override public Object answer(InvocationOnMock invocation) throws Throwable { - assertThat(SecurityContextHolder.getContext()) + assertThat(strategy.getContext()) .isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext); return super.answer(invocation); } @@ -122,6 +131,20 @@ public class DelegatingSecurityContextCallableTests { assertWrapped(this.callable); } + @Test + public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception { + SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy()); + givenDelegateCallWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy); + securityContextHolderStrategy.setContext(this.securityContext); + DelegatingSecurityContextCallable callable = new DelegatingSecurityContextCallable<>(this.delegate); + callable.setSecurityContextHolderStrategy(securityContextHolderStrategy); + this.callable = callable; + // ensure callable is what sets up the SecurityContextHolder + securityContextHolderStrategy.clearContext(); + assertWrapped(this.callable); + verify(securityContextHolderStrategy, atLeastOnce()).getContext(); + } + // SEC-3031 @Test public void callOnSameThread() throws Exception { diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java index 8b1c3852f6..458a46e105 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java @@ -30,12 +30,16 @@ import org.mockito.stubbing.Answer; import org.springframework.core.task.SyncTaskExecutor; import org.springframework.core.task.support.ExecutorServiceAdapter; +import org.springframework.security.core.context.MockSecurityContextHolderStrategy; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; /** @@ -73,6 +77,13 @@ public class DelegatingSecurityContextRunnableTests { }).given(this.delegate).run(); } + private void givenDelegateRunWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy) { + willAnswer((Answer) (invocation) -> { + assertThat(strategy.getContext()).isEqualTo(this.securityContext); + return null; + }).given(this.delegate).run(); + } + @AfterEach public void tearDown() { SecurityContextHolder.clearContext(); @@ -117,6 +128,20 @@ public class DelegatingSecurityContextRunnableTests { assertWrapped(this.runnable); } + @Test + public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception { + SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy()); + givenDelegateRunWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy); + securityContextHolderStrategy.setContext(this.securityContext); + DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(this.delegate); + runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy); + this.runnable = runnable; + // ensure callable is what sets up the SecurityContextHolder + securityContextHolderStrategy.clearContext(); + assertWrapped(this.runnable); + verify(securityContextHolderStrategy, atLeastOnce()).getContext(); + } + // SEC-3031 @Test public void callOnSameThread() throws Exception { diff --git a/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java b/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java new file mode 100644 index 0000000000..3f4ddfba11 --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy { + + private SecurityContext context; + + @Override + public void clearContext() { + this.context = null; + } + + @Override + public SecurityContext getContext() { + return this.context; + } + + @Override + public void setContext(SecurityContext context) { + this.context = context; + } + + @Override + public SecurityContext createEmptyContext() { + return new SecurityContextImpl(); + } + +}