Use SecurityContextHolderStrategy for Context Propagation

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:38:56 -06:00
parent d18ff25b95
commit 459003e1b3
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
8 changed files with 227 additions and 36 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,6 +19,9 @@ package org.springframework.security.concurrent;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
/** /**
* An internal support class that wraps {@link Callable} with * An internal support class that wraps {@link Callable} with
@ -30,6 +33,9 @@ import org.springframework.security.core.context.SecurityContext;
*/ */
abstract class AbstractDelegatingSecurityContextSupport { abstract class AbstractDelegatingSecurityContextSupport {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final SecurityContext securityContext; private final SecurityContext securityContext;
/** /**
@ -44,12 +50,19 @@ abstract class AbstractDelegatingSecurityContextSupport {
this.securityContext = securityContext; this.securityContext = securityContext;
} }
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
protected final Runnable wrap(Runnable delegate) { protected final Runnable wrap(Runnable delegate) {
return DelegatingSecurityContextRunnable.create(delegate, this.securityContext); return DelegatingSecurityContextRunnable.create(delegate, this.securityContext,
this.securityContextHolderStrategy);
} }
protected final <T> Callable<T> wrap(Callable<T> delegate) { protected final <T> Callable<T> wrap(Callable<T> delegate) {
return DelegatingSecurityContextCallable.create(delegate, this.securityContext); return DelegatingSecurityContextCallable.create(delegate, this.securityContext,
this.securityContextHolderStrategy);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.util.concurrent.Callable;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -40,10 +41,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
private final Callable<V> delegate; private final Callable<V> delegate;
private final boolean explicitSecurityContextProvided;
/** /**
* The {@link SecurityContext} that the delegate {@link Callable} will be ran as. * The {@link SecurityContext} that the delegate {@link Callable} will be ran as.
*/ */
private final SecurityContext delegateSecurityContext; private SecurityContext delegateSecurityContext;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/** /**
* The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@ -60,10 +66,7 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
* {@link Callable}. Cannot be null. * {@link Callable}. Cannot be null.
*/ */
public DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext) { public DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext) {
Assert.notNull(delegate, "delegate cannot be null"); this(delegate, securityContext, true);
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
} }
/** /**
@ -73,28 +76,51 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
* {@link SecurityContext}. Cannot be null. * {@link SecurityContext}. Cannot be null.
*/ */
public DelegatingSecurityContextCallable(Callable<V> delegate) { public DelegatingSecurityContextCallable(Callable<V> delegate) {
this(delegate, SecurityContextHolder.getContext()); this(delegate, SecurityContextHolder.getContext(), false);
}
private DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext,
boolean explicitSecurityContextProvided) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this.explicitSecurityContextProvided = explicitSecurityContextProvided;
} }
@Override @Override
public V call() throws Exception { public V call() throws Exception {
this.originalSecurityContext = SecurityContextHolder.getContext(); this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
try { try {
SecurityContextHolder.setContext(this.delegateSecurityContext); this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
return this.delegate.call(); return this.delegate.call();
} }
finally { finally {
SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
if (emptyContext.equals(this.originalSecurityContext)) { if (emptyContext.equals(this.originalSecurityContext)) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
} }
else { else {
SecurityContextHolder.setContext(this.originalSecurityContext); this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
} }
this.originalSecurityContext = null; this.originalSecurityContext = null;
} }
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
if (!this.explicitSecurityContextProvided) {
this.delegateSecurityContext = securityContextHolderStrategy.getContext();
}
}
@Override @Override
public String toString() { public String toString() {
return this.delegate.toString(); return this.delegate.toString();
@ -116,4 +142,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
: new DelegatingSecurityContextCallable<>(delegate); : new DelegatingSecurityContextCallable<>(delegate);
} }
static <V> Callable<V> create(Callable<V> delegate, SecurityContext securityContext,
SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
DelegatingSecurityContextCallable<V> callable = (securityContext != null)
? new DelegatingSecurityContextCallable<>(delegate, securityContext)
: new DelegatingSecurityContextCallable<>(delegate);
callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
return callable;
}
} }

View File

@ -20,6 +20,7 @@ import java.util.concurrent.Executor;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -66,4 +67,14 @@ public class DelegatingSecurityContextExecutor extends AbstractDelegatingSecurit
return this.delegate; return this.delegate;
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
super.setSecurityContextHolderStrategy(securityContextHolderStrategy);
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.security.concurrent;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -38,10 +39,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
private final Runnable delegate; private final Runnable delegate;
private final boolean explicitSecurityContextProvided;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/** /**
* The {@link SecurityContext} that the delegate {@link Runnable} will be ran as. * The {@link SecurityContext} that the delegate {@link Runnable} will be ran as.
*/ */
private final SecurityContext delegateSecurityContext; private SecurityContext delegateSecurityContext;
/** /**
* The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@ -58,10 +64,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
* {@link Runnable}. Cannot be null. * {@link Runnable}. Cannot be null.
*/ */
public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) { public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) {
Assert.notNull(delegate, "delegate cannot be null"); this(delegate, securityContext, true);
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
} }
/** /**
@ -71,28 +74,51 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
* {@link SecurityContext}. Cannot be null. * {@link SecurityContext}. Cannot be null.
*/ */
public DelegatingSecurityContextRunnable(Runnable delegate) { public DelegatingSecurityContextRunnable(Runnable delegate) {
this(delegate, SecurityContextHolder.getContext()); this(delegate, SecurityContextHolder.getContext(), false);
}
private DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext,
boolean explicitSecurityContextProvided) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContext, "securityContext cannot be null");
this.delegate = delegate;
this.delegateSecurityContext = securityContext;
this.explicitSecurityContextProvided = explicitSecurityContextProvided;
} }
@Override @Override
public void run() { public void run() {
this.originalSecurityContext = SecurityContextHolder.getContext(); this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
try { try {
SecurityContextHolder.setContext(this.delegateSecurityContext); this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
this.delegate.run(); this.delegate.run();
} }
finally { finally {
SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
if (emptyContext.equals(this.originalSecurityContext)) { if (emptyContext.equals(this.originalSecurityContext)) {
SecurityContextHolder.clearContext(); this.securityContextHolderStrategy.clearContext();
} }
else { else {
SecurityContextHolder.setContext(this.originalSecurityContext); this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
} }
this.originalSecurityContext = null; this.originalSecurityContext = null;
} }
} }
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
if (!this.explicitSecurityContextProvided) {
this.delegateSecurityContext = this.securityContextHolderStrategy.getContext();
}
}
@Override @Override
public String toString() { public String toString() {
return this.delegate.toString(); return this.delegate.toString();
@ -114,4 +140,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
: new DelegatingSecurityContextRunnable(delegate); : new DelegatingSecurityContextRunnable(delegate);
} }
static Runnable create(Runnable delegate, SecurityContext securityContext,
SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(delegate, "delegate cannot be null");
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
DelegatingSecurityContextRunnable runnable = (securityContext != null)
? new DelegatingSecurityContextRunnable(delegate, securityContext)
: new DelegatingSecurityContextRunnable(delegate);
runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
return runnable;
}
} }

View File

@ -30,7 +30,9 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
/** /**
* Abstract base class for testing classes that extend * Abstract base class for testing classes that extend
@ -71,18 +73,18 @@ public abstract class AbstractDelegatingSecurityContextTestSupport {
protected MockedStatic<DelegatingSecurityContextRunnable> delegatingSecurityContextRunnable; protected MockedStatic<DelegatingSecurityContextRunnable> delegatingSecurityContextRunnable;
public final void explicitSecurityContextSetup() throws Exception { public final void explicitSecurityContextSetup() throws Exception {
this.delegatingSecurityContextCallable.when( this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable),
() -> DelegatingSecurityContextCallable.create(eq(this.callable), this.securityContextCaptor.capture())) this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedCallable);
.thenReturn(this.wrappedCallable); this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable),
this.delegatingSecurityContextRunnable.when( this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedRunnable);
() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), this.securityContextCaptor.capture()))
.thenReturn(this.wrappedRunnable);
} }
public final void currentSecurityContextSetup() throws Exception { public final void currentSecurityContextSetup() throws Exception {
this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(this.callable, null)) this.delegatingSecurityContextCallable
.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable), isNull(), any()))
.thenReturn(this.wrappedCallable); .thenReturn(this.wrappedCallable);
this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(this.runnable, null)) this.delegatingSecurityContextRunnable
.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), isNull(), any()))
.thenReturn(this.wrappedRunnable); .thenReturn(this.wrappedRunnable);
} }

View File

@ -30,12 +30,16 @@ import org.mockito.internal.stubbing.answers.Returns;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
@ -68,10 +72,15 @@ public class DelegatingSecurityContextCallableTests {
} }
private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception { private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception {
givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolder.getContextHolderStrategy());
}
private void givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy)
throws Exception {
given(this.delegate.call()).willAnswer(new Returns(this.callableResult) { given(this.delegate.call()).willAnswer(new Returns(this.callableResult) {
@Override @Override
public Object answer(InvocationOnMock invocation) throws Throwable { public Object answer(InvocationOnMock invocation) throws Throwable {
assertThat(SecurityContextHolder.getContext()) assertThat(strategy.getContext())
.isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext); .isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext);
return super.answer(invocation); return super.answer(invocation);
} }
@ -122,6 +131,20 @@ public class DelegatingSecurityContextCallableTests {
assertWrapped(this.callable); assertWrapped(this.callable);
} }
@Test
public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
givenDelegateCallWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
securityContextHolderStrategy.setContext(this.securityContext);
DelegatingSecurityContextCallable<Object> callable = new DelegatingSecurityContextCallable<>(this.delegate);
callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
this.callable = callable;
// ensure callable is what sets up the SecurityContextHolder
securityContextHolderStrategy.clearContext();
assertWrapped(this.callable);
verify(securityContextHolderStrategy, atLeastOnce()).getContext();
}
// SEC-3031 // SEC-3031
@Test @Test
public void callOnSameThread() throws Exception { public void callOnSameThread() throws Exception {

View File

@ -30,12 +30,16 @@ import org.mockito.stubbing.Answer;
import org.springframework.core.task.SyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.support.ExecutorServiceAdapter; import org.springframework.core.task.support.ExecutorServiceAdapter;
import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
@ -73,6 +77,13 @@ public class DelegatingSecurityContextRunnableTests {
}).given(this.delegate).run(); }).given(this.delegate).run();
} }
private void givenDelegateRunWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy) {
willAnswer((Answer<Object>) (invocation) -> {
assertThat(strategy.getContext()).isEqualTo(this.securityContext);
return null;
}).given(this.delegate).run();
}
@AfterEach @AfterEach
public void tearDown() { public void tearDown() {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
@ -117,6 +128,20 @@ public class DelegatingSecurityContextRunnableTests {
assertWrapped(this.runnable); assertWrapped(this.runnable);
} }
@Test
public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
givenDelegateRunWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
securityContextHolderStrategy.setContext(this.securityContext);
DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(this.delegate);
runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
this.runnable = runnable;
// ensure callable is what sets up the SecurityContextHolder
securityContextHolderStrategy.clearContext();
assertWrapped(this.runnable);
verify(securityContextHolderStrategy, atLeastOnce()).getContext();
}
// SEC-3031 // SEC-3031
@Test @Test
public void callOnSameThread() throws Exception { public void callOnSameThread() throws Exception {

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.core.context;
public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
private SecurityContext context;
@Override
public void clearContext() {
this.context = null;
}
@Override
public SecurityContext getContext() {
return this.context;
}
@Override
public void setContext(SecurityContext context) {
this.context = context;
}
@Override
public SecurityContext createEmptyContext() {
return new SecurityContextImpl();
}
}