diff --git a/webflux/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryServerWebExchange.java b/webflux/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryServerWebExchange.java index a3c657b6bc..dfa480a2d1 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryServerWebExchange.java +++ b/webflux/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryServerWebExchange.java @@ -43,6 +43,7 @@ final class SecurityContextRepositoryServerWebExchange extends ServerWebExchange this.repository.load(this) .filter(c -> c.getAuthentication() != null) .flatMap(c -> Mono.just((T) c.getAuthentication())) + .switchIfEmpty( super.getPrincipal() ) ); } } diff --git a/webflux/src/test/java/org/springframework/security/web/server/context/SecurityContextRepositoryWebFilterTests.java b/webflux/src/test/java/org/springframework/security/web/server/context/SecurityContextRepositoryWebFilterTests.java index 3a33f25c5c..c1793ab583 100644 --- a/webflux/src/test/java/org/springframework/security/web/server/context/SecurityContextRepositoryWebFilterTests.java +++ b/webflux/src/test/java/org/springframework/security/web/server/context/SecurityContextRepositoryWebFilterTests.java @@ -24,11 +24,15 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.test.web.reactive.server.WebTestHandler; +import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.security.Principal; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; @@ -39,6 +43,9 @@ import static org.mockito.Mockito.*; */ @RunWith(MockitoJUnitRunner.class) public class SecurityContextRepositoryWebFilterTests { + @Mock + Authentication principal; + @Mock SecurityContextRepository repository; @@ -80,13 +87,29 @@ public class SecurityContextRepositoryWebFilterTests { verifyZeroInteractions(repository); } + // We must use the original principal if the result is empty for test support to work @Test - public void filterWhenGetPrincipalThenInteract() { + public void filterWhenEmptyAndGetPrincipalThenInteractAndUseOriginalPrincipal() { when(repository.load(any())).thenReturn(Mono.empty()); filters = WebTestHandler.bindToWebFilters(filter, (e,c) -> e.getPrincipal().flatMap( p-> c.filter(e))) ; - filters.exchange(exchange); + ServerWebExchange exchangeWithPrincipal = this.exchange.toExchange().mutate().principal(Mono.just(principal)).build(); + WebTestHandler.WebHandlerResult result = filters.exchange(exchangeWithPrincipal); verify(repository).load(any()); + assertThat(result.getExchange().getPrincipal().block()).isSameAs(principal); + } + + @Test + public void filterWhenPrincipalAndGetPrincipalThenInteractAndUseOriginalPrincipal() { + SecurityContextImpl context = new SecurityContextImpl(); + context.setAuthentication(principal); + when(repository.load(any())).thenReturn(Mono.just(context)); + filters = WebTestHandler.bindToWebFilters(filter, (e,c) -> e.getPrincipal().flatMap( p-> c.filter(e))) ; + + WebTestHandler.WebHandlerResult result = filters.exchange(exchange); + + verify(repository).load(any()); + assertThat(result.getExchange().getPrincipal().block()).isSameAs(principal); } }