diff --git a/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java index 1ab3ac9e63..06158272e7 100644 --- a/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java @@ -18,6 +18,8 @@ package org.springframework.security.core.context; import java.util.Arrays; import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import org.springframework.util.Assert; @@ -127,9 +129,9 @@ public final class ListeningSecurityContextHolderStrategy implements SecurityCon */ @Override public void clearContext() { - SecurityContext from = getContext(); + Supplier deferred = this.delegate.getDeferredContext(); this.delegate.clearContext(); - publish(from, null); + publish(new SecurityContextChangedEvent(deferred, SecurityContextChangedEvent.NO_CONTEXT)); } /** @@ -140,14 +142,28 @@ public final class ListeningSecurityContextHolderStrategy implements SecurityCon return this.delegate.getContext(); } + /** + * {@inheritDoc} + */ + @Override + public Supplier getDeferredContext() { + return this.delegate.getDeferredContext(); + } + /** * {@inheritDoc} */ @Override public void setContext(SecurityContext context) { - SecurityContext from = getContext(); - this.delegate.setContext(context); - publish(from, context); + setDeferredContext(() -> context); + } + + /** + * {@inheritDoc} + */ + @Override + public void setDeferredContext(Supplier deferredContext) { + this.delegate.setDeferredContext(new PublishOnceSupplier(getDeferredContext(), deferredContext)); } /** @@ -158,14 +174,42 @@ public final class ListeningSecurityContextHolderStrategy implements SecurityCon return this.delegate.createEmptyContext(); } - private void publish(SecurityContext previous, SecurityContext current) { - if (previous == current) { - return; - } - SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current); + private void publish(SecurityContextChangedEvent event) { for (SecurityContextChangedListener listener : this.listeners) { listener.securityContextChanged(event); } } + class PublishOnceSupplier implements Supplier { + + private final AtomicBoolean isPublished = new AtomicBoolean(false); + + private final Supplier old; + + private final Supplier updated; + + PublishOnceSupplier(Supplier old, Supplier updated) { + if (old instanceof PublishOnceSupplier) { + this.old = ((PublishOnceSupplier) old).updated; + } + else { + this.old = old; + } + this.updated = updated; + } + + @Override + public SecurityContext get() { + SecurityContext updated = this.updated.get(); + if (this.isPublished.compareAndSet(false, true)) { + SecurityContext old = this.old.get(); + if (old != updated) { + publish(new SecurityContextChangedEvent(old, updated)); + } + } + return updated; + } + + } + } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextChangedEvent.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextChangedEvent.java index 17cd787ad5..c14125c475 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextChangedEvent.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextChangedEvent.java @@ -16,6 +16,8 @@ package org.springframework.security.core.context; +import java.util.function.Supplier; + import org.springframework.context.ApplicationEvent; /** @@ -26,9 +28,24 @@ import org.springframework.context.ApplicationEvent; */ public class SecurityContextChangedEvent extends ApplicationEvent { - private final SecurityContext oldContext; + public static final Supplier NO_CONTEXT = () -> null; - private final SecurityContext newContext; + private final Supplier oldContext; + + private final Supplier newContext; + + /** + * Construct an event + * @param oldContext the old security context + * @param newContext the new security context, use + * {@link SecurityContextChangedEvent#NO_CONTEXT} for if the context is cleared + * @since 5.8 + */ + public SecurityContextChangedEvent(Supplier oldContext, Supplier newContext) { + super(SecurityContextHolder.class); + this.oldContext = oldContext; + this.newContext = newContext; + } /** * Construct an event @@ -36,9 +53,7 @@ public class SecurityContextChangedEvent extends ApplicationEvent { * @param newContext the new security context */ public SecurityContextChangedEvent(SecurityContext oldContext, SecurityContext newContext) { - super(SecurityContextHolder.class); - this.oldContext = oldContext; - this.newContext = newContext; + this(() -> oldContext, (newContext != null) ? () -> newContext : NO_CONTEXT); } /** @@ -47,7 +62,7 @@ public class SecurityContextChangedEvent extends ApplicationEvent { * @return the previous {@link SecurityContext} */ public SecurityContext getOldContext() { - return this.oldContext; + return this.oldContext.get(); } /** @@ -56,7 +71,21 @@ public class SecurityContextChangedEvent extends ApplicationEvent { * @return the current {@link SecurityContext} */ public SecurityContext getNewContext() { - return this.newContext; + return this.newContext.get(); + } + + /** + * Say whether the event is a context-clearing event. + * + *

+ * This method is handy for avoiding looking up the new context to confirm it is a + * cleared event. + * @return {@code true} if the new context is + * {@link SecurityContextChangedEvent#NO_CONTEXT} + * @since 5.8 + */ + public boolean isCleared() { + return this.newContext == NO_CONTEXT; } } diff --git a/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java b/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java index 46f6a60b4e..a8bb5a06b2 100644 --- a/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java +++ b/core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java @@ -16,27 +16,36 @@ package org.springframework.security.core.context; -import org.junit.jupiter.api.Test; +import java.util.function.Supplier; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; public class ListeningSecurityContextHolderStrategyTests { @Test public void setContextWhenInvokedThenListenersAreNotified() { - SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class); + SecurityContextHolderStrategy delegate = spy(new MockSecurityContextHolderStrategy()); SecurityContextChangedListener one = mock(SecurityContextChangedListener.class); SecurityContextChangedListener two = mock(SecurityContextChangedListener.class); SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, one, two); given(delegate.createEmptyContext()).willReturn(new SecurityContextImpl()); SecurityContext context = strategy.createEmptyContext(); strategy.setContext(context); - verify(delegate).setContext(context); + strategy.getContext(); verify(one).securityContextChanged(any()); verify(two).securityContextChanged(any()); } @@ -49,10 +58,68 @@ public class ListeningSecurityContextHolderStrategyTests { SecurityContext context = new SecurityContextImpl(); given(delegate.getContext()).willReturn(context); strategy.setContext(strategy.getContext()); - verify(delegate).setContext(context); + strategy.getContext(); verifyNoInteractions(listener); } + @Test + public void clearContextWhenNoGetContextThenContextIsNotRead() { + SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class); + SecurityContextChangedListener listener = mock(SecurityContextChangedListener.class); + SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, listener); + Supplier context = mock(Supplier.class); + ArgumentCaptor event = ArgumentCaptor.forClass(SecurityContextChangedEvent.class); + given(delegate.getDeferredContext()).willReturn(context); + given(delegate.getContext()).willAnswer((invocation) -> context.get()); + strategy.clearContext(); + verifyNoInteractions(context); + verify(listener).securityContextChanged(event.capture()); + assertThat(event.getValue().isCleared()).isTrue(); + strategy.getContext(); + verify(context).get(); + strategy.clearContext(); + verifyNoMoreInteractions(context); + } + + @Test + public void getContextWhenCalledMultipleTimesThenEventPublishedOnce() { + SecurityContextHolderStrategy delegate = new MockSecurityContextHolderStrategy(); + SecurityContextChangedListener listener = mock(SecurityContextChangedListener.class); + SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, listener); + strategy.setContext(new SecurityContextImpl()); + verifyNoInteractions(listener); + strategy.getContext(); + verify(listener).securityContextChanged(any()); + strategy.getContext(); + verifyNoMoreInteractions(listener); + } + + @Test + public void setContextWhenCalledMultipleTimesThenPublishedEventsAlign() { + SecurityContextHolderStrategy delegate = new MockSecurityContextHolderStrategy(); + SecurityContextChangedListener listener = mock(SecurityContextChangedListener.class); + SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, listener); + SecurityContext one = new SecurityContextImpl(new TestingAuthenticationToken("user", "pass")); + SecurityContext two = new SecurityContextImpl(new TestingAuthenticationToken("admin", "pass")); + ArgumentCaptor event = ArgumentCaptor.forClass(SecurityContextChangedEvent.class); + strategy.setContext(one); + strategy.setContext(two); + verifyNoInteractions(listener); + strategy.getContext(); + verify(listener).securityContextChanged(event.capture()); + assertThat(event.getValue().getOldContext()).isEqualTo(one); + assertThat(event.getValue().getNewContext()).isEqualTo(two); + strategy.getContext(); + verifyNoMoreInteractions(listener); + strategy.setContext(one); + verifyNoMoreInteractions(listener); + reset(listener); + strategy.getContext(); + verify(listener).securityContextChanged(event.capture()); + assertThat(event.getValue().getOldContext()).isEqualTo(two); + assertThat(event.getValue().getNewContext()).isEqualTo(one); + } + @Test public void constructorWhenNullDelegateThenIllegalArgument() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy( diff --git a/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java b/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java index 3f4ddfba11..3977925cae 100644 --- a/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java +++ b/core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java @@ -16,23 +16,35 @@ package org.springframework.security.core.context; +import java.util.function.Supplier; + public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy { - private SecurityContext context; + private Supplier context = () -> null; @Override public void clearContext() { - this.context = null; + this.context = () -> null; } @Override public SecurityContext getContext() { + return this.context.get(); + } + + @Override + public Supplier getDeferredContext() { return this.context; } @Override public void setContext(SecurityContext context) { - this.context = context; + this.context = () -> context; + } + + @Override + public void setDeferredContext(Supplier deferredContext) { + this.context = deferredContext; } @Override