diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java index 2114fc32b5..0b7928d41c 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java @@ -25,10 +25,8 @@ import org.springframework.util.Assert; * then removing the {@link SecurityContext} after the delegate has completed. *

*

- * By default the {@link SecurityContext} is only setup if {@link #call()} is - * invoked on a separate {@link Thread} than the - * {@link DelegatingSecurityContextCallable} was created on. This can be - * overridden by setting {@link #setEnableOnOriginalThread(boolean)} to true. + * If there is a {@link SecurityContext} that already exists, it will be + * restored after the {@link #call()} method is invoked. *

* * @author Rob Winch @@ -38,11 +36,18 @@ public final class DelegatingSecurityContextCallable implements Callable { private final Callable delegate; - private final SecurityContext securityContext; - private final Thread originalThread; + /** + * The {@link SecurityContext} that the delegate {@link Callable} will be + * ran as. + */ + private final SecurityContext delegateSecurityContext; - private boolean enableOnOriginalThread; + /** + * The {@link SecurityContext} that was on the {@link SecurityContextHolder} + * prior to being set to the delegateSecurityContext. + */ + private SecurityContext originalSecurityContext; /** * Creates a new {@link DelegatingSecurityContextCallable} with a specific @@ -57,8 +62,7 @@ public final class DelegatingSecurityContextCallable implements Callable { Assert.notNull(delegate, "delegate cannot be null"); Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; - this.securityContext = securityContext; - this.originalThread = Thread.currentThread(); + this.delegateSecurityContext = securityContext; } /** @@ -71,33 +75,21 @@ public final class DelegatingSecurityContextCallable implements Callable { this(delegate, SecurityContextHolder.getContext()); } - /** - * Determines if the SecurityContext should be transfered if {@link #call()} - * is invoked on the same {@link Thread} the - * {@link DelegatingSecurityContextCallable} was created on. - * - * @param enableOnOriginalThread - * if false (default), will only transfer the - * {@link SecurityContext} if {@link #call()} is invoked on a - * different {@link Thread} than the - * {@link DelegatingSecurityContextCallable} was created on. - * - * @since 4.0.2 - */ - public void setEnableOnOriginalThread(boolean enableOnOriginalThread) { - this.enableOnOriginalThread = enableOnOriginalThread; - } - public V call() throws Exception { - if(!enableOnOriginalThread && originalThread == Thread.currentThread()) { - return delegate.call(); - } + this.originalSecurityContext = SecurityContextHolder.getContext(); + try { - SecurityContextHolder.setContext(securityContext); + SecurityContextHolder.setContext(delegateSecurityContext); return delegate.call(); } finally { - SecurityContextHolder.clearContext(); + SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + if(emptyContext.equals(originalSecurityContext)) { + SecurityContextHolder.clearContext(); + } else { + SecurityContextHolder.setContext(originalSecurityContext); + } + this.originalSecurityContext = null; } } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java index 95d3c76920..38a652c2f6 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java @@ -23,10 +23,8 @@ import org.springframework.util.Assert; * {@link SecurityContext} after the delegate has completed. *

*

- * By default the {@link SecurityContext} is only setup if {@link #run()} is - * invoked on a separate {@link Thread} than the - * {@link DelegatingSecurityContextRunnable} was created on. This can be - * overridden by setting {@link #setEnableOnOriginalThread(boolean)} to true. + * If there is a {@link SecurityContext} that already exists, it will be + * restored after the {@link #run()} method is invoked. *

* * @author Rob Winch @@ -36,11 +34,17 @@ public final class DelegatingSecurityContextRunnable implements Runnable { private final Runnable delegate; - private final SecurityContext securityContext; + /** + * The {@link SecurityContext} that the delegate {@link Runnable} will be + * ran as. + */ + private final SecurityContext delegateSecurityContext; - private final Thread originalThread; - - private boolean enableOnOriginalThread; + /** + * The {@link SecurityContext} that was on the {@link SecurityContextHolder} + * prior to being set to the delegateSecurityContext. + */ + private SecurityContext originalSecurityContext; /** * Creates a new {@link DelegatingSecurityContextRunnable} with a specific @@ -55,8 +59,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable { Assert.notNull(delegate, "delegate cannot be null"); Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; - this.securityContext = securityContext; - this.originalThread = Thread.currentThread(); + this.delegateSecurityContext = securityContext; } /** @@ -69,33 +72,21 @@ public final class DelegatingSecurityContextRunnable implements Runnable { this(delegate, SecurityContextHolder.getContext()); } - /** - * Determines if the SecurityContext should be transfered if {@link #run()} - * is invoked on the same {@link Thread} the - * {@link DelegatingSecurityContextCallable} was created on. - * - * @param enableOnOriginalThread - * if false (default), will only transfer the - * {@link SecurityContext} if {@link #run()} is invoked on a - * different {@link Thread} than the - * {@link DelegatingSecurityContextCallable} was created on. - * @since 4.0.2 - */ - public void setEnableOnOriginalThread(boolean enableOnOriginalThread) { - this.enableOnOriginalThread = enableOnOriginalThread; - } - public void run() { - if(!enableOnOriginalThread && originalThread == Thread.currentThread()) { - delegate.run(); - return; - } + this.originalSecurityContext = SecurityContextHolder.getContext(); + try { - SecurityContextHolder.setContext(securityContext); + SecurityContextHolder.setContext(delegateSecurityContext); delegate.run(); } finally { - SecurityContextHolder.clearContext(); + SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + if(emptyContext.equals(originalSecurityContext)) { + SecurityContextHolder.clearContext(); + } else { + SecurityContextHolder.setContext(originalSecurityContext); + } + this.originalSecurityContext = null; } } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java index d6644de83d..a49e863344 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java @@ -50,9 +50,12 @@ public class DelegatingSecurityContextCallableTests { private ExecutorService executor; + private SecurityContext originalSecurityContext; + @Before @SuppressWarnings("serial") public void setUp() throws Exception { + originalSecurityContext = SecurityContextHolder.createEmptyContext(); when(delegate.call()).thenAnswer(new Returns(callableResult) { @Override public Object answer(InvocationOnMock invocation) throws Throwable { @@ -111,17 +114,10 @@ public class DelegatingSecurityContextCallableTests { // SEC-3031 @Test public void callOnSameThread() throws Exception { + originalSecurityContext = securityContext; + SecurityContextHolder.setContext(originalSecurityContext); callable = new DelegatingSecurityContextCallable(delegate, securityContext); - securityContext = SecurityContextHolder.createEmptyContext(); - assertWrapped(callable.call()); - } - - @Test - public void callOnSameThreadExplicitlyEnabled() throws Exception { - DelegatingSecurityContextCallable callable = new DelegatingSecurityContextCallable(delegate, - securityContext); - callable.setEnableOnOriginalThread(true); assertWrapped(callable.call()); } @@ -170,6 +166,6 @@ public class DelegatingSecurityContextCallableTests { private void assertWrapped(Object callableResult) throws Exception { verify(delegate).call(); assertThat(SecurityContextHolder.getContext()).isEqualTo( - SecurityContextHolder.createEmptyContext()); + originalSecurityContext); } } \ No newline at end of file diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java index fc5c6e70c1..1f1ad3718e 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java @@ -51,8 +51,11 @@ public class DelegatingSecurityContextRunnableTests { private ExecutorService executor; + private SecurityContext originalSecurityContext; + @Before public void setUp() throws Exception { + originalSecurityContext = SecurityContextHolder.createEmptyContext(); doAnswer(new Answer() { public Object answer(InvocationOnMock invocation) throws Throwable { assertThat(SecurityContextHolder.getContext()).isEqualTo(securityContext); @@ -110,19 +113,11 @@ public class DelegatingSecurityContextRunnableTests { // SEC-3031 @Test public void callOnSameThread() throws Exception { + originalSecurityContext = securityContext; + SecurityContextHolder.setContext(originalSecurityContext); executor = synchronousExecutor(); runnable = new DelegatingSecurityContextRunnable(delegate, securityContext); - securityContext = SecurityContextHolder.createEmptyContext(); - assertWrapped(runnable); - } - - @Test - public void callOnSameThreadExplicitlyEnabled() throws Exception { - executor = synchronousExecutor(); - DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(delegate, - securityContext); - runnable.setEnableOnOriginalThread(true); assertWrapped(runnable); } @@ -167,7 +162,7 @@ public class DelegatingSecurityContextRunnableTests { submit.get(); verify(delegate).run(); assertThat(SecurityContextHolder.getContext()).isEqualTo( - SecurityContextHolder.createEmptyContext()); + originalSecurityContext); } private static ExecutorService synchronousExecutor() {