diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java index 3cfba87101..2114fc32b5 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java @@ -19,9 +19,17 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; /** - * Wraps a delegate {@link Callable} with logic for setting up a {@link SecurityContext} - * before invoking the delegate {@link Callable} and then removing the - * {@link SecurityContext} after the delegate has completed. + *

+ * Wraps a delegate {@link Callable} with logic for setting up a + * {@link SecurityContext} before invoking the delegate {@link Callable} and + * then removing the {@link SecurityContext} after the delegate has completed. + *

+ *

+ * By default the {@link SecurityContext} is only setup if {@link #call()} is + * invoked on a separate {@link Thread} than the + * {@link DelegatingSecurityContextCallable} was created on. This can be + * overridden by setting {@link #setEnableOnOriginalThread(boolean)} to true. + *

* * @author Rob Winch * @since 3.2 @@ -32,6 +40,10 @@ public final class DelegatingSecurityContextCallable implements Callable { private final SecurityContext securityContext; + private final Thread originalThread; + + private boolean enableOnOriginalThread; + /** * Creates a new {@link DelegatingSecurityContextCallable} with a specific * {@link SecurityContext}. @@ -46,6 +58,7 @@ public final class DelegatingSecurityContextCallable implements Callable { Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; this.securityContext = securityContext; + this.originalThread = Thread.currentThread(); } /** @@ -58,7 +71,27 @@ public final class DelegatingSecurityContextCallable implements Callable { this(delegate, SecurityContextHolder.getContext()); } + /** + * Determines if the SecurityContext should be transfered if {@link #call()} + * is invoked on the same {@link Thread} the + * {@link DelegatingSecurityContextCallable} was created on. + * + * @param enableOnOriginalThread + * if false (default), will only transfer the + * {@link SecurityContext} if {@link #call()} is invoked on a + * different {@link Thread} than the + * {@link DelegatingSecurityContextCallable} was created on. + * + * @since 4.0.2 + */ + public void setEnableOnOriginalThread(boolean enableOnOriginalThread) { + this.enableOnOriginalThread = enableOnOriginalThread; + } + public V call() throws Exception { + if(!enableOnOriginalThread && originalThread == Thread.currentThread()) { + return delegate.call(); + } try { SecurityContextHolder.setContext(securityContext); return delegate.call(); diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java index 6674c8c6ec..30f9d58f29 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java @@ -17,9 +17,17 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; /** + *

* Wraps a delegate {@link Runnable} with logic for setting up a {@link SecurityContext} * before invoking the delegate {@link Runnable} and then removing the * {@link SecurityContext} after the delegate has completed. + *

+ *

+ * By default the {@link SecurityContext} is only setup if {@link #run()} is + * invoked on a separate {@link Thread} than the + * {@link DelegatingSecurityContextRunnable} was created on. This can be + * overridden by setting {@link #setEnableOnOriginalThread(boolean)} to true. + *

* * @author Rob Winch * @since 3.2 @@ -30,6 +38,10 @@ public final class DelegatingSecurityContextRunnable implements Runnable { private final SecurityContext securityContext; + private final Thread originalThread; + + private boolean enableOnOriginalThread; + /** * Creates a new {@link DelegatingSecurityContextRunnable} with a specific * {@link SecurityContext}. @@ -44,6 +56,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable { Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; this.securityContext = securityContext; + this.originalThread = Thread.currentThread(); } /** @@ -56,7 +69,27 @@ public final class DelegatingSecurityContextRunnable implements Runnable { this(delegate, SecurityContextHolder.getContext()); } + /** + * Determines if the SecurityContext should be transfered if {@link #call()} + * is invoked on the same {@link Thread} the + * {@link DelegatingSecurityContextCallable} was created on. + * + * @param enableOnOriginalThread + * if false (default), will only transfer the + * {@link SecurityContext} if {@link #call()} is invoked on a + * different {@link Thread} than the + * {@link DelegatingSecurityContextCallable} was created on. + * @since 4.0.2 + */ + public void setEnableOnOriginalThread(boolean enableOnOriginalThread) { + this.enableOnOriginalThread = enableOnOriginalThread; + } + public void run() { + if(!enableOnOriginalThread && originalThread == Thread.currentThread()) { + delegate.run(); + return; + } try { SecurityContextHolder.setContext(securityContext); delegate.run(); diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java index c3eeafaf92..d6644de83d 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java @@ -17,6 +17,9 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.junit.After; import org.junit.Before; @@ -45,6 +48,8 @@ public class DelegatingSecurityContextCallableTests { private Callable callable; + private ExecutorService executor; + @Before @SuppressWarnings("serial") public void setUp() throws Exception { @@ -55,6 +60,7 @@ public class DelegatingSecurityContextCallableTests { return super.answer(invocation); } }); + executor = Executors.newFixedThreadPool(1); } @After @@ -90,7 +96,7 @@ public class DelegatingSecurityContextCallableTests { public void call() throws Exception { callable = new DelegatingSecurityContextCallable(delegate, securityContext); - assertWrapped(callable.call()); + assertWrapped(callable); } @Test @@ -99,6 +105,23 @@ public class DelegatingSecurityContextCallableTests { callable = new DelegatingSecurityContextCallable(delegate); SecurityContextHolder.clearContext(); // ensure callable is what sets up the // SecurityContextHolder + assertWrapped(callable); + } + + // SEC-3031 + @Test + public void callOnSameThread() throws Exception { + callable = new DelegatingSecurityContextCallable(delegate, + securityContext); + securityContext = SecurityContextHolder.createEmptyContext(); + assertWrapped(callable.call()); + } + + @Test + public void callOnSameThreadExplicitlyEnabled() throws Exception { + DelegatingSecurityContextCallable callable = new DelegatingSecurityContextCallable(delegate, + securityContext); + callable.setEnableOnOriginalThread(true); assertWrapped(callable.call()); } @@ -120,13 +143,13 @@ public class DelegatingSecurityContextCallableTests { callable = DelegatingSecurityContextCallable.create(delegate, null); SecurityContextHolder.clearContext(); // ensure callable is what sets up the // SecurityContextHolder - assertWrapped(callable.call()); + assertWrapped(callable); } @Test public void create() throws Exception { callable = DelegatingSecurityContextCallable.create(delegate, securityContext); - assertWrapped(callable.call()); + assertWrapped(callable); } // --- toString @@ -139,8 +162,12 @@ public class DelegatingSecurityContextCallableTests { assertThat(callable.toString()).isEqualTo(delegate.toString()); } - private void assertWrapped(Object actualResult) throws Exception { - assertThat(actualResult).isEqualTo(callableResult); + private void assertWrapped(Callable callable) throws Exception { + Future submit = executor.submit(callable); + assertWrapped(submit.get()); + } + + private void assertWrapped(Object callableResult) throws Exception { verify(delegate).call(); assertThat(SecurityContextHolder.getContext()).isEqualTo( SecurityContextHolder.createEmptyContext()); diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java index c7031e0fae..fc5c6e70c1 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java @@ -16,6 +16,10 @@ import static org.fest.assertions.Assertions.assertThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -24,6 +28,8 @@ import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.core.task.support.ExecutorServiceAdapter; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -43,6 +49,8 @@ public class DelegatingSecurityContextRunnableTests { private Runnable runnable; + private ExecutorService executor; + @Before public void setUp() throws Exception { doAnswer(new Answer() { @@ -51,6 +59,8 @@ public class DelegatingSecurityContextRunnableTests { return null; } }).when(delegate).run(); + + executor = Executors.newFixedThreadPool(1); } @After @@ -85,8 +95,7 @@ public class DelegatingSecurityContextRunnableTests { @Test public void call() throws Exception { runnable = new DelegatingSecurityContextRunnable(delegate, securityContext); - runnable.run(); - assertWrapped(); + assertWrapped(runnable); } @Test @@ -95,8 +104,26 @@ public class DelegatingSecurityContextRunnableTests { runnable = new DelegatingSecurityContextRunnable(delegate); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the // SecurityContextHolder - runnable.run(); - assertWrapped(); + assertWrapped(runnable); + } + + // SEC-3031 + @Test + public void callOnSameThread() throws Exception { + executor = synchronousExecutor(); + runnable = new DelegatingSecurityContextRunnable(delegate, + securityContext); + securityContext = SecurityContextHolder.createEmptyContext(); + assertWrapped(runnable); + } + + @Test + public void callOnSameThreadExplicitlyEnabled() throws Exception { + executor = synchronousExecutor(); + DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(delegate, + securityContext); + runnable.setEnableOnOriginalThread(true); + assertWrapped(runnable); } // --- create --- @@ -112,20 +139,18 @@ public class DelegatingSecurityContextRunnableTests { } @Test - public void createNullSecurityContext() { + public void createNullSecurityContext() throws Exception { SecurityContextHolder.setContext(securityContext); runnable = DelegatingSecurityContextRunnable.create(delegate, null); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the // SecurityContextHolder - runnable.run(); - assertWrapped(); + assertWrapped(runnable); } @Test - public void create() { + public void create() throws Exception { runnable = DelegatingSecurityContextRunnable.create(delegate, securityContext); - runnable.run(); - assertWrapped(); + assertWrapped(runnable); } // --- toString @@ -137,9 +162,15 @@ public class DelegatingSecurityContextRunnableTests { assertThat(runnable.toString()).isEqualTo(delegate.toString()); } - private void assertWrapped() { + private void assertWrapped(Runnable runnable) throws Exception { + Future submit = executor.submit(runnable); + submit.get(); verify(delegate).run(); assertThat(SecurityContextHolder.getContext()).isEqualTo( SecurityContextHolder.createEmptyContext()); } + + private static ExecutorService synchronousExecutor() { + return new ExecutorServiceAdapter(new SyncTaskExecutor()); + } } \ No newline at end of file diff --git a/web/src/test/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilterTests.java b/web/src/test/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilterTests.java index ab27dab8ce..076d666658 100644 --- a/web/src/test/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilterTests.java @@ -321,7 +321,7 @@ public class SecurityContextHolderAwareRequestFilterTests { .getValue(); assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)) .isEqualTo(context); - assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)) + assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, "delegate")) .isEqualTo(runnable); } @@ -348,7 +348,7 @@ public class SecurityContextHolderAwareRequestFilterTests { .getValue(); assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)) .isEqualTo(context); - assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)) + assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, "delegate")) .isEqualTo(runnable); } @@ -375,7 +375,7 @@ public class SecurityContextHolderAwareRequestFilterTests { .getValue(); assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)) .isEqualTo(context); - assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)) + assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, "delegate")) .isEqualTo(runnable); }