diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java index 72259a549b..0d7321a9f0 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java @@ -17,6 +17,7 @@ package org.springframework.security.test.context.support; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; +import java.util.function.Supplier; import org.springframework.beans.BeanUtils; import org.springframework.core.GenericTypeResolver; @@ -69,11 +70,12 @@ public class WithSecurityContextTestExecutionListener return; } - SecurityContext securityContext = testSecurityContext.securityContext; + Supplier supplier = testSecurityContext + .getSecurityContextSupplier(); if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) { - TestSecurityContextHolder.setContext(securityContext); + TestSecurityContextHolder.setContext(supplier.get()); } else { - testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, securityContext); + testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier); } } @@ -83,9 +85,10 @@ public class WithSecurityContextTestExecutionListener */ @Override public void beforeTestExecution(TestContext testContext) { - SecurityContext securityContext = (SecurityContext) testContext.removeAttribute(SECURITY_CONTEXT_ATTR_NAME); - if (securityContext != null) { - TestSecurityContextHolder.setContext(securityContext); + Supplier supplier = (Supplier) testContext + .removeAttribute(SECURITY_CONTEXT_ATTR_NAME); + if (supplier != null) { + TestSecurityContextHolder.setContext(supplier.get()); } } @@ -118,14 +121,16 @@ public class WithSecurityContextTestExecutionListener .resolveTypeArgument(factory.getClass(), WithSecurityContextFactory.class); Annotation annotation = findAnnotation(annotated, type); + Supplier supplier = () -> { + try { + return factory.createSecurityContext(annotation); + } catch (RuntimeException e) { + throw new IllegalStateException( + "Unable to create SecurityContext using " + annotation, e); + } + }; TestExecutionEvent initialize = withSecurityContext.setupBefore(); - try { - return new TestSecurityContext(factory.createSecurityContext(annotation), initialize); - } - catch (RuntimeException e) { - throw new IllegalStateException( - "Unable to create SecurityContext using " + annotation, e); - } + return new TestSecurityContext(supplier, initialize); } private Annotation findAnnotation(AnnotatedElement annotated, @@ -179,16 +184,17 @@ public class WithSecurityContextTestExecutionListener } static class TestSecurityContext { - private final SecurityContext securityContext; + private final Supplier securityContextSupplier; private final TestExecutionEvent testExecutionEvent; - TestSecurityContext(SecurityContext securityContext, TestExecutionEvent testExecutionEvent) { - this.securityContext = securityContext; + TestSecurityContext(Supplier securityContextSupplier, + TestExecutionEvent testExecutionEvent) { + this.securityContextSupplier = securityContextSupplier; this.testExecutionEvent = testExecutionEvent; } - public SecurityContext getSecurityContext() { - return this.securityContext; + public Supplier getSecurityContextSupplier() { + return this.securityContextSupplier; } public TestExecutionEvent getTestExecutionEvent() { diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java index fe7804f2c2..61dc8d7011 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java @@ -21,6 +21,8 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.beans.factory.annotation.Autowired; @@ -36,6 +38,7 @@ import org.springframework.test.context.junit4.rules.SpringClassRule; import org.springframework.test.context.junit4.rules.SpringMethodRule; import java.lang.reflect.Method; +import java.util.function.Supplier; import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; @@ -102,7 +105,23 @@ public class WithSecurityContextTestExecutionListenerTests { this.listener.beforeTestMethod(this.testContext); assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull(); - verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); + verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME) + , ArgumentMatchers.>any()); + } + + @Test + @SuppressWarnings("unchecked") + public void beforeTestMethodWhenWithMockUserTestExecutionThenTestContextSupplierOk() throws Exception { + Method testMethod = TheTest.class.getMethod("withMockUserTestExecution"); + when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); + when(this.testContext.getTestMethod()).thenReturn(testMethod); + + this.listener.beforeTestMethod(this.testContext); + + ArgumentCaptor> supplierCaptor = ArgumentCaptor.forClass(Supplier.class); + verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), + supplierCaptor.capture()); + assertThat(supplierCaptor.getValue().get().getAuthentication()).isNotNull(); } @Test @@ -116,7 +135,8 @@ public class WithSecurityContextTestExecutionListenerTests { public void beforeTestExecutionWhenTestContextNotNullThenSecurityContextSet() { SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(new TestingAuthenticationToken("user", "passsword", "ROLE_USER")); - when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(securityContext); + Supplier supplier = () -> securityContext; + when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(supplier); this.listener.beforeTestExecution(this.testContext);