Use ReactorSecurityContextHolder

Issue gh-4713
This commit is contained in:
Rob Winch 2017-10-25 16:18:27 -05:00
parent 9ea4df5b5d
commit 747473257f
8 changed files with 55 additions and 43 deletions

View File

@ -25,7 +25,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringRunner;
import reactor.core.publisher.Flux;
@ -34,8 +34,6 @@ import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.util.context.Context;
import java.util.function.Function;
import static org.mockito.Mockito.*;
/**
@ -49,10 +47,8 @@ public class EnableReactiveMethodSecurityTests {
ReactiveMessageService delegate;
TestPublisher<String> result = TestPublisher.create();
Function<Context, Context> withAdmin = context -> context.put(Authentication.class, Mono
.just(new TestingAuthenticationToken("admin","password","ROLE_USER", "ROLE_ADMIN")));
Function<Context, Context> withUser = context -> context.put(Authentication.class, Mono
.just(new TestingAuthenticationToken("user","password","ROLE_USER")));
Context withAdmin = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("admin","password","ROLE_USER", "ROLE_ADMIN"));
Context withUser = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("user","password","ROLE_USER"));
@After
public void cleanup() {

View File

@ -31,6 +31,8 @@ import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.security.core.userdetails.User;
@ -106,8 +108,8 @@ public class EnableWebFluxSecurityTests {
chain.filter(exchange.mutate().principal(Mono.just(currentPrincipal)).build()),
this.springSecurityFilterChain,
(exchange,chain) ->
Mono.subscriberContext()
.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.flatMap( principal -> exchange.getResponse()
.writeWith(Mono.just(toDataBuffer(principal.getName()))))
).build();
@ -126,8 +128,8 @@ public class EnableWebFluxSecurityTests {
WebTestClient client = WebTestClientBuilder.bindToWebFilters(
this.springSecurityFilterChain,
(exchange,chain) ->
Mono.subscriberContext()
.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.flatMap( principal -> exchange.getResponse()
.writeWith(Mono.just(toDataBuffer(principal.getName()))))
)
@ -154,8 +156,8 @@ public class EnableWebFluxSecurityTests {
WebTestClient client = WebTestClientBuilder.bindToWebFilters(
this.springSecurityFilterChain,
(exchange,chain) ->
Mono.subscriberContext()
.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.flatMap( principal -> exchange.getResponse()
.writeWith(Mono.just(toDataBuffer(principal.getName()))))
)

View File

@ -25,11 +25,12 @@ import org.springframework.security.access.method.MethodSecurityMetadataSource;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.util.Assert;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.lang.reflect.Method;
import java.util.Collection;
@ -68,9 +69,9 @@ public class PrePostAdviceReactiveMethodInterceptor implements MethodInterceptor
.getAttributes(method, targetClass);
PreInvocationAttribute preAttr = findPreInvocationAttribute(attributes);
Mono<Authentication> toInvoke = Mono.subscriberContext()
.defaultIfEmpty(Context.empty())
.flatMap( cxt -> cxt.getOrDefault(Authentication.class, Mono.just(anonymous)))
Mono<Authentication> toInvoke = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(this.anonymous)
.filter( auth -> this.preInvocationAdvice.before(auth, invocation, preAttr))
.switchIfEmpty(Mono.error(new AccessDeniedException("Denied")));

View File

@ -18,6 +18,7 @@ package org.springframework.security.test.context.support;
import org.reactivestreams.Subscription;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.test.context.TestSecurityContextHolder;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestExecutionListener;
@ -25,7 +26,6 @@ import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.util.ClassUtils;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.util.context.Context;
@ -76,7 +76,8 @@ public class ReactorContextTestExecutionListener
if (authentication == null) {
return context;
}
return context.put(Authentication.class, Mono.just(authentication));
Context toMerge = ReactiveSecurityContextHolder.withAuthentication(authentication);
return context.putAll(toMerge);
}
@Override

View File

@ -20,6 +20,8 @@ import static org.assertj.core.api.Assertions.assertThat;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
@ -42,8 +44,8 @@ public class SecurityTestExecutionListenerTests {
@WithMockUser
@Test
public void reactorContextTestSecurityContextHolderExecutionListenerTestIsRegistered() {
Mono<String> name = Mono.subscriberContext()
.flatMap( context -> context.<Mono<Authentication>>get(Authentication.class))
Mono<String> name = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.map(Principal::getName);
StepVerifier.create(name)

View File

@ -26,6 +26,8 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@ -108,8 +110,8 @@ public class ReactorContextTestExecutionListenerTests {
}
public void assertAuthentication(Authentication expected) {
Mono<Authentication> authentication = Mono.subscriberContext()
.flatMap( context -> context.<Mono<Authentication>>get(Authentication.class));
Mono<Authentication> authentication = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication);
StepVerifier.create(authentication)
.expectNext(expected)

View File

@ -17,6 +17,8 @@
package org.springframework.security.web.server.context;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
@ -38,6 +40,13 @@ public class AuthenticationReactorContextWebFilter implements WebFilter {
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return chain.filter(exchange)
.subscriberContext((Context context) -> context.put(Authentication.class, exchange.getPrincipal()));
.subscriberContext(createContext(exchange));
}
private Context createContext(ServerWebExchange exchange) {
return exchange.getPrincipal()
.cast(Authentication.class)
.map(SecurityContextImpl::new)
.as(ReactiveSecurityContextHolder::withSecurityContext);
}
}

View File

@ -21,11 +21,12 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.handler.DefaultWebFilterChain;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
import java.security.Principal;
@ -47,12 +48,12 @@ public class AuthenticationReactorContextWebFilterTests {
exchange = exchange.mutate().principal(Mono.just(principal)).build();
StepVerifier.create(filter.filter(exchange,
new DefaultWebFilterChain( e ->
Mono.subscriberContext().doOnSuccess( context -> {
Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
assertThat(contextPrincipal).isEqualTo(principal);
assertThat(context.<String>get("foo")).isEqualTo("bar");
})
.then()
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
.flatMap( contextPrincipal -> Mono.subscriberContext())
.doOnSuccess( context -> assertThat(context.<String>get("foo")).isEqualTo("bar"))
.then()
)
)
.subscriberContext( context -> context.put("foo", "bar")))
@ -64,11 +65,10 @@ public class AuthenticationReactorContextWebFilterTests {
exchange = exchange.mutate().principal(Mono.just(principal)).build();
StepVerifier.create(filter.filter(exchange,
new DefaultWebFilterChain( e ->
Mono.subscriberContext().doOnSuccess( context -> {
Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
assertThat(contextPrincipal).isEqualTo(principal);
})
.then()
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
.then()
)
))
.verifyComplete();
@ -76,15 +76,14 @@ public class AuthenticationReactorContextWebFilterTests {
@Test
public void filterWhenPrincipalNullThenContextEmpty() {
Context defaultContext = Context.empty();
Authentication defaultAuthentication = new TestingAuthenticationToken("anonymouse","anonymous", "TEST");
StepVerifier.create(filter.filter(exchange,
new DefaultWebFilterChain( e ->
Mono.subscriberContext()
.defaultIfEmpty(defaultContext)
.doOnSuccess( context -> {
Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
assertThat(contextPrincipal).isNull();
})
ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(defaultAuthentication)
.doOnSuccess( contextPrincipal -> assertThat(contextPrincipal).isEqualTo(defaultAuthentication)
)
.then()
)
))