diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index 80d9fd74ad..591e91aa07 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -16,9 +16,6 @@ package org.springframework.security.web.server.csrf; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -233,21 +230,16 @@ public class CsrfWebFilterTests { // gh-9113 @Test public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() { + PublisherProbe chainResult = PublisherProbe.empty(); this.csrfFilter.setCsrfTokenRepository(this.repository); given(this.repository.loadToken(any())).willReturn(Mono.empty()); - AtomicInteger count = new AtomicInteger(); - given(this.repository.generateToken(any())).willReturn(Mono.fromCallable(() -> { - count.incrementAndGet(); - return this.token; - })); - given(this.repository.saveToken(any(), any())).willReturn(Mono.empty()); - AtomicReference> tokenFromExchange = new AtomicReference<>(); - given(this.chain.filter(any())).willReturn( - Mono.fromRunnable(() -> tokenFromExchange.set(this.get.getAttribute(CsrfToken.class.getName())))); + given(this.repository.generateToken(any())).willReturn(chainResult.mono()); + given(this.chain.filter(any())).willReturn(Mono.empty()); this.csrfFilter.filter(this.get, this.chain).block(); - tokenFromExchange.get().block(); - tokenFromExchange.get().block(); - assertThat(count).hasValue(1); + Mono result = this.get.getAttribute(CsrfToken.class.getName()); + result.block(); + result.block(); + assertThat(chainResult.subscribeCount()).isEqualTo(1); } @RestController