Fix SecurityContext creation for TEST_EXECUTION

Currently, there is support for setting up a SecurityContext after @Before by
using TestExecutionEvent.TEST_EXECUTION. The current implementation, however,
already creates the SecurityContext in @Before and just does not set it yet.
This leads to issues like #6591. For the case of @WithUserDetails, the
creation of the SecurityContext already looks up a user from the repository.
If the user was inserted in @Before, the user is not found despite using
TestExecutionEvent.TEST_EXECUTION. This commit changes the creation of the
SecurityContext to happen after @Before if using
TestExecutionEvent.TEST_EXECUTION.

Closes gh-6591
This commit is contained in:
Markus Gabriel 2020-06-21 13:20:59 +02:00 committed by Rob Winch
parent c71352c548
commit 97ee6d66f1
2 changed files with 46 additions and 20 deletions

View File

@ -17,6 +17,7 @@ package org.springframework.security.test.context.support;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement; import java.lang.reflect.AnnotatedElement;
import java.util.function.Supplier;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.core.GenericTypeResolver; import org.springframework.core.GenericTypeResolver;
@ -69,11 +70,12 @@ public class WithSecurityContextTestExecutionListener
return; return;
} }
SecurityContext securityContext = testSecurityContext.securityContext; Supplier<SecurityContext> supplier = testSecurityContext
.getSecurityContextSupplier();
if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) { if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
TestSecurityContextHolder.setContext(securityContext); TestSecurityContextHolder.setContext(supplier.get());
} else { } else {
testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, securityContext); testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier);
} }
} }
@ -83,9 +85,10 @@ public class WithSecurityContextTestExecutionListener
*/ */
@Override @Override
public void beforeTestExecution(TestContext testContext) { public void beforeTestExecution(TestContext testContext) {
SecurityContext securityContext = (SecurityContext) testContext.removeAttribute(SECURITY_CONTEXT_ATTR_NAME); Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext
if (securityContext != null) { .removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
TestSecurityContextHolder.setContext(securityContext); if (supplier != null) {
TestSecurityContextHolder.setContext(supplier.get());
} }
} }
@ -118,14 +121,16 @@ public class WithSecurityContextTestExecutionListener
.resolveTypeArgument(factory.getClass(), .resolveTypeArgument(factory.getClass(),
WithSecurityContextFactory.class); WithSecurityContextFactory.class);
Annotation annotation = findAnnotation(annotated, type); Annotation annotation = findAnnotation(annotated, type);
Supplier<SecurityContext> supplier = () -> {
try {
return factory.createSecurityContext(annotation);
} catch (RuntimeException e) {
throw new IllegalStateException(
"Unable to create SecurityContext using " + annotation, e);
}
};
TestExecutionEvent initialize = withSecurityContext.setupBefore(); TestExecutionEvent initialize = withSecurityContext.setupBefore();
try { return new TestSecurityContext(supplier, initialize);
return new TestSecurityContext(factory.createSecurityContext(annotation), initialize);
}
catch (RuntimeException e) {
throw new IllegalStateException(
"Unable to create SecurityContext using " + annotation, e);
}
} }
private Annotation findAnnotation(AnnotatedElement annotated, private Annotation findAnnotation(AnnotatedElement annotated,
@ -179,16 +184,17 @@ public class WithSecurityContextTestExecutionListener
} }
static class TestSecurityContext { static class TestSecurityContext {
private final SecurityContext securityContext; private final Supplier<SecurityContext> securityContextSupplier;
private final TestExecutionEvent testExecutionEvent; private final TestExecutionEvent testExecutionEvent;
TestSecurityContext(SecurityContext securityContext, TestExecutionEvent testExecutionEvent) { TestSecurityContext(Supplier<SecurityContext> securityContextSupplier,
this.securityContext = securityContext; TestExecutionEvent testExecutionEvent) {
this.securityContextSupplier = securityContextSupplier;
this.testExecutionEvent = testExecutionEvent; this.testExecutionEvent = testExecutionEvent;
} }
public SecurityContext getSecurityContext() { public Supplier<SecurityContext> getSecurityContextSupplier() {
return this.securityContext; return this.securityContextSupplier;
} }
public TestExecutionEvent getTestExecutionEvent() { public TestExecutionEvent getTestExecutionEvent() {

View File

@ -21,6 +21,8 @@ import org.junit.ClassRule;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -36,6 +38,7 @@ import org.springframework.test.context.junit4.rules.SpringClassRule;
import org.springframework.test.context.junit4.rules.SpringMethodRule; import org.springframework.test.context.junit4.rules.SpringMethodRule;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.function.Supplier;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -102,7 +105,23 @@ public class WithSecurityContextTestExecutionListenerTests {
this.listener.beforeTestMethod(this.testContext); this.listener.beforeTestMethod(this.testContext);
assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull();
verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)
, ArgumentMatchers.<Supplier<SecurityContext>>any());
}
@Test
@SuppressWarnings("unchecked")
public void beforeTestMethodWhenWithMockUserTestExecutionThenTestContextSupplierOk() throws Exception {
Method testMethod = TheTest.class.getMethod("withMockUserTestExecution");
when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext);
when(this.testContext.getTestMethod()).thenReturn(testMethod);
this.listener.beforeTestMethod(this.testContext);
ArgumentCaptor<Supplier<SecurityContext>> supplierCaptor = ArgumentCaptor.forClass(Supplier.class);
verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME),
supplierCaptor.capture());
assertThat(supplierCaptor.getValue().get().getAuthentication()).isNotNull();
} }
@Test @Test
@ -116,7 +135,8 @@ public class WithSecurityContextTestExecutionListenerTests {
public void beforeTestExecutionWhenTestContextNotNullThenSecurityContextSet() { public void beforeTestExecutionWhenTestContextNotNullThenSecurityContextSet() {
SecurityContextImpl securityContext = new SecurityContextImpl(); SecurityContextImpl securityContext = new SecurityContextImpl();
securityContext.setAuthentication(new TestingAuthenticationToken("user", "passsword", "ROLE_USER")); securityContext.setAuthentication(new TestingAuthenticationToken("user", "passsword", "ROLE_USER"));
when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(securityContext); Supplier<SecurityContext> supplier = () -> securityContext;
when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(supplier);
this.listener.beforeTestExecution(this.testContext); this.listener.beforeTestExecution(this.testContext);