Use SecurityContextHolderStrategy for Context Propagation

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:38:56 -06:00
parent 5357cb8c95
commit 38cb6c3172
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
8 changed files with 227 additions and 36 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,9 @@ package org.springframework.security.concurrent;
import java.util.concurrent.Callable;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
* An internal support class that wraps {@link Callable} with
@ -30,6 +33,9 @@ import org.springframework.security.core.context.SecurityContext;
*/
abstract class AbstractDelegatingSecurityContextSupport {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final SecurityContext securityContext;
/**
@ -44,12 +50,19 @@ abstract class AbstractDelegatingSecurityContextSupport {
this.securityContext = securityContext;
}
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
protected final Runnable wrap(Runnable delegate) {
return DelegatingSecurityContextRunnable.create(delegate, this.securityContext);
return DelegatingSecurityContextRunnable.create(delegate, this.securityContext,
this.securityContextHolderStrategy);
}
protected final <T> Callable<T> wrap(Callable<T> delegate) {
return DelegatingSecurityContextCallable.create(delegate, this.securityContext);
return DelegatingSecurityContextCallable.create(delegate, this.securityContext,
this.securityContextHolderStrategy);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.util.concurrent.Callable;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
@ -40,10 +41,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
private final Callable<V> delegate;
private final boolean explicitSecurityContextProvided;
/**
* The {@link SecurityContext} that the delegate {@link Callable} will be ran as.
*/
private final SecurityContext delegateSecurityContext;
private SecurityContext delegateSecurityContext;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/**
* The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@ -60,10 +66,7 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
* {@link Callable}. Cannot be null.
*/
public DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this(delegate, securityContext, true);
}
/**
@ -73,28 +76,51 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
* {@link SecurityContext}. Cannot be null.
*/
public DelegatingSecurityContextCallable(Callable<V> delegate) {
this(delegate, SecurityContextHolder.getContext());
this(delegate, SecurityContextHolder.getContext(), false);
}
private DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext,
boolean explicitSecurityContextProvided) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this.explicitSecurityContextProvided = explicitSecurityContextProvided;
}
@Override
public V call() throws Exception {
this.originalSecurityContext = SecurityContextHolder.getContext();
this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
try {
SecurityContextHolder.setContext(this.delegateSecurityContext);
this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
return this.delegate.call();
}
finally {
SecurityContext emptyContext = SecurityContextHolder.createEmptyContext();
SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
if (emptyContext.equals(this.originalSecurityContext)) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
}
else {
SecurityContextHolder.setContext(this.originalSecurityContext);
this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
}
this.originalSecurityContext = null;
}
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
if (!this.explicitSecurityContextProvided) {
this.delegateSecurityContext = securityContextHolderStrategy.getContext();
}
}
@Override
public String toString() {
return this.delegate.toString();
@ -116,4 +142,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
: new DelegatingSecurityContextCallable<>(delegate);
}
static <V> Callable<V> create(Callable<V> delegate, SecurityContext securityContext,
SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
DelegatingSecurityContextCallable<V> callable = (securityContext != null)
? new DelegatingSecurityContextCallable<>(delegate, securityContext)
: new DelegatingSecurityContextCallable<>(delegate);
callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
return callable;
}
}

View File

@ -20,6 +20,7 @@ import java.util.concurrent.Executor;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
@ -66,4 +67,14 @@ public class DelegatingSecurityContextExecutor extends AbstractDelegatingSecurit
return this.delegate;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
super.setSecurityContextHolderStrategy(securityContextHolderStrategy);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.security.concurrent;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/**
@ -38,10 +39,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
private final Runnable delegate;
private final boolean explicitSecurityContextProvided;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/**
* The {@link SecurityContext} that the delegate {@link Runnable} will be ran as.
*/
private final SecurityContext delegateSecurityContext;
private SecurityContext delegateSecurityContext;
/**
* The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@ -58,10 +64,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
* {@link Runnable}. Cannot be null.
*/
public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this(delegate, securityContext, true);
}
/**
@ -71,28 +74,51 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
* {@link SecurityContext}. Cannot be null.
*/
public DelegatingSecurityContextRunnable(Runnable delegate) {
this(delegate, SecurityContextHolder.getContext());
this(delegate, SecurityContextHolder.getContext(), false);
}
private DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext,
boolean explicitSecurityContextProvided) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this.explicitSecurityContextProvided = explicitSecurityContextProvided;
}
@Override
public void run() {
this.originalSecurityContext = SecurityContextHolder.getContext();
this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
try {
SecurityContextHolder.setContext(this.delegateSecurityContext);
this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
this.delegate.run();
}
finally {
SecurityContext emptyContext = SecurityContextHolder.createEmptyContext();
SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
if (emptyContext.equals(this.originalSecurityContext)) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
}
else {
SecurityContextHolder.setContext(this.originalSecurityContext);
this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
}
this.originalSecurityContext = null;
}
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
if (!this.explicitSecurityContextProvided) {
this.delegateSecurityContext = this.securityContextHolderStrategy.getContext();
}
}
@Override
public String toString() {
return this.delegate.toString();
@ -114,4 +140,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
: new DelegatingSecurityContextRunnable(delegate);
}
static Runnable create(Runnable delegate, SecurityContext securityContext,
SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
DelegatingSecurityContextRunnable runnable = (securityContext != null)
? new DelegatingSecurityContextRunnable(delegate, securityContext)
: new DelegatingSecurityContextRunnable(delegate);
runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
return runnable;
}
}

View File

@ -30,7 +30,9 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
/**
* Abstract base class for testing classes that extend
@ -71,18 +73,18 @@ public abstract class AbstractDelegatingSecurityContextTestSupport {
protected MockedStatic<DelegatingSecurityContextRunnable> delegatingSecurityContextRunnable;
public final void explicitSecurityContextSetup() throws Exception {
this.delegatingSecurityContextCallable.when(
() -> DelegatingSecurityContextCallable.create(eq(this.callable), this.securityContextCaptor.capture()))
.thenReturn(this.wrappedCallable);
this.delegatingSecurityContextRunnable.when(
() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), this.securityContextCaptor.capture()))
.thenReturn(this.wrappedRunnable);
this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable),
this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedCallable);
this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable),
this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedRunnable);
}
public final void currentSecurityContextSetup() throws Exception {
this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(this.callable, null))
this.delegatingSecurityContextCallable
.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable), isNull(), any()))
.thenReturn(this.wrappedCallable);
this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(this.runnable, null))
this.delegatingSecurityContextRunnable
.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), isNull(), any()))
.thenReturn(this.wrappedRunnable);
}

View File

@ -30,12 +30,16 @@ import org.mockito.internal.stubbing.answers.Returns;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/**
@ -68,10 +72,15 @@ public class DelegatingSecurityContextCallableTests {
}
private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception {
givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolder.getContextHolderStrategy());
}
private void givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy)
throws Exception {
given(this.delegate.call()).willAnswer(new Returns(this.callableResult) {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
assertThat(SecurityContextHolder.getContext())
assertThat(strategy.getContext())
.isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext);
return super.answer(invocation);
}
@ -122,6 +131,20 @@ public class DelegatingSecurityContextCallableTests {
assertWrapped(this.callable);
}
@Test
public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
givenDelegateCallWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
securityContextHolderStrategy.setContext(this.securityContext);
DelegatingSecurityContextCallable<Object> callable = new DelegatingSecurityContextCallable<>(this.delegate);
callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
this.callable = callable;
// ensure callable is what sets up the SecurityContextHolder
securityContextHolderStrategy.clearContext();
assertWrapped(this.callable);
verify(securityContextHolderStrategy, atLeastOnce()).getContext();
}
// SEC-3031
@Test
public void callOnSameThread() throws Exception {

View File

@ -30,12 +30,16 @@ import org.mockito.stubbing.Answer;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.support.ExecutorServiceAdapter;
import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/**
@ -73,6 +77,13 @@ public class DelegatingSecurityContextRunnableTests {
}).given(this.delegate).run();
}
private void givenDelegateRunWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy) {
willAnswer((Answer<Object>) (invocation) -> {
assertThat(strategy.getContext()).isEqualTo(this.securityContext);
return null;
}).given(this.delegate).run();
}
@AfterEach
public void tearDown() {
SecurityContextHolder.clearContext();
@ -117,6 +128,20 @@ public class DelegatingSecurityContextRunnableTests {
assertWrapped(this.runnable);
}
@Test
public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
givenDelegateRunWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
securityContextHolderStrategy.setContext(this.securityContext);
DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(this.delegate);
runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
this.runnable = runnable;
// ensure callable is what sets up the SecurityContextHolder
securityContextHolderStrategy.clearContext();
assertWrapped(this.runnable);
verify(securityContextHolderStrategy, atLeastOnce()).getContext();
}
// SEC-3031
@Test
public void callOnSameThread() throws Exception {

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.core.context;
public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
private SecurityContext context;
@Override
public void clearContext() {
this.context = null;
}
@Override
public SecurityContext getContext() {
return this.context;
}
@Override
public void setContext(SecurityContext context) {
this.context = context;
}
@Override
public SecurityContext createEmptyContext() {
return new SecurityContextImpl();
}
}