diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index cd2454be7a..e4b4981812 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -21,6 +21,9 @@ import org.springframework.lang.Nullable; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.test.context.TestSecurityContextHolder; @@ -33,12 +36,9 @@ import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import reactor.core.publisher.Mono; -import java.security.Principal; import java.util.Collection; import java.util.List; import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; /** * Test utilities for working with Spring Security and @@ -58,22 +58,12 @@ public class SecurityMockServerConfigurers { public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters( filters -> { filters.add(0, new MutatorFilter()); - filters.add(0, new SetupMutatorFilter(createMutator( () -> TestSecurityContextHolder.getContext().getAuthentication()))); + filters.add(0, new SetupMutatorFilter(TestSecurityContextHolder.getContext())); }); } }; } - /** - * Updates the ServerWebExchange to use the provided Principal - * - * @param principal the principal to use. - * @return the {@link WebTestClientConfigurer} to use - */ - public static T mockPrincipal(Principal principal) { - return (T) new MutatorWebTestClientConfigurer(createMutator(() -> principal)); - } - /** * Updates the ServerWebExchange to use the provided Authentication as the Principal * @@ -81,7 +71,7 @@ public class SecurityMockServerConfigurers { * @return the {@link WebTestClientConfigurer}} to use */ public static T mockAuthentication(Authentication authentication) { - return mockPrincipal(authentication); + return (T) new MutatorWebTestClientConfigurer(authentication); } /** @@ -118,10 +108,6 @@ public class SecurityMockServerConfigurers { return new UserExchangeMutator(username); } - private static Function createMutator(Supplier principal) { - return m -> principal.get() == null ? m : m.mutate().principal(Mono.just(principal.get())).build(); - } - /** * Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a * password of "password" and granted authorities of "ROLE_USER". @@ -230,12 +216,21 @@ public class SecurityMockServerConfigurers { } private static class MutatorWebTestClientConfigurer implements WebTestClientConfigurer, MockServerConfigurer { - private final Function mutator; + private final Mono context; - private MutatorWebTestClientConfigurer(Function mutator) { - this.mutator = mutator; + private MutatorWebTestClientConfigurer(Mono context) { + this.context = context; } + private MutatorWebTestClientConfigurer(SecurityContext context) { + this(Mono.just(context)); + } + + private MutatorWebTestClientConfigurer(Authentication authentication) { + this(new SecurityContextImpl(authentication)); + } + + @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters(addSetupMutatorFilter()); @@ -247,34 +242,42 @@ public class SecurityMockServerConfigurers { } private Consumer> addSetupMutatorFilter() { - return filters -> filters.add(0, new SetupMutatorFilter(mutator)); + return filters -> filters.add(0, new SetupMutatorFilter(this.context)); } } private static class SetupMutatorFilter implements WebFilter { - private final Function mutator; + private final Mono context; - private SetupMutatorFilter(Function mutator) { - this.mutator = mutator; + private SetupMutatorFilter(Mono context) { + this.context = context; + } + + private SetupMutatorFilter(SecurityContext context) { + this(Mono.just(context)); + } + + private SetupMutatorFilter(Authentication authentication) { + this(new SecurityContextImpl(authentication)); } @Override public Mono filter(ServerWebExchange exchange, WebFilterChain webFilterChain) { - exchange.getAttributes().computeIfAbsent(MutatorFilter.ATTRIBUTE_NAME, key -> mutator); + exchange.getAttributes().computeIfAbsent(MutatorFilter.ATTRIBUTE_NAME, key -> this.context); return webFilterChain.filter(exchange); } } private static class MutatorFilter implements WebFilter { - - public static final String ATTRIBUTE_NAME = "mutator"; + public static final String ATTRIBUTE_NAME = "context"; @Override public Mono filter(ServerWebExchange exchange, WebFilterChain webFilterChain) { - Function mutator = exchange.getAttribute(ATTRIBUTE_NAME); - if(mutator != null) { + Mono context = exchange.getAttribute(ATTRIBUTE_NAME); + if(context != null) { exchange.getAttributes().remove(ATTRIBUTE_NAME); - exchange = mutator.apply(exchange); + return webFilterChain.filter(exchange) + .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(context)); } return webFilterChain.filter(exchange); }