AuthenticationSuccessHandler uses WebFilterExchange

Issue gh-4532
This commit is contained in:
Rob Winch 2017-09-07 14:11:04 -05:00
parent ef9cf1d54b
commit a6bed9a1aa
6 changed files with 21 additions and 17 deletions

View File

@ -19,8 +19,7 @@
package org.springframework.security.web.server.authentication; package org.springframework.security.web.server.authentication;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.web.server.ServerWebExchange; import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
/** /**
@ -28,5 +27,5 @@ import reactor.core.publisher.Mono;
* @since 5.0 * @since 5.0
*/ */
public interface AuthenticationSuccessHandler { public interface AuthenticationSuccessHandler {
Mono<Void> success(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain); Mono<Void> success(Authentication authentication, WebFilterExchange webFilterExchange);
} }

View File

@ -27,6 +27,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.AuthenticationEntryPoint;
import org.springframework.security.web.server.HttpBasicAuthenticationConverter; import org.springframework.security.web.server.HttpBasicAuthenticationConverter;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint;
import org.springframework.security.web.server.context.SecurityContextRepository; import org.springframework.security.web.server.context.SecurityContextRepository;
import org.springframework.security.web.server.context.SecurityContextRepositoryServerWebExchange; import org.springframework.security.web.server.context.SecurityContextRepositoryServerWebExchange;
@ -87,7 +88,7 @@ public class AuthenticationWebFilter implements WebFilter {
SecurityContextImpl securityContext = new SecurityContextImpl(); SecurityContextImpl securityContext = new SecurityContextImpl();
securityContext.setAuthentication(authentication); securityContext.setAuthentication(authentication);
return this.securityContextRepository.save(exchange, securityContext) return this.securityContextRepository.save(exchange, securityContext)
.then(this.authenticationSuccessHandler.success(authentication, exchange, chain)); .then(this.authenticationSuccessHandler.success(authentication, new WebFilterExchange(exchange, chain)));
} }
public void setSecurityContextRepository( public void setSecurityContextRepository(

View File

@ -21,6 +21,7 @@ package org.springframework.security.web.server.authentication;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.server.DefaultRedirectStrategy; import org.springframework.security.web.server.DefaultRedirectStrategy;
import org.springframework.security.web.server.RedirectStrategy; import org.springframework.security.web.server.RedirectStrategy;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
@ -38,8 +39,8 @@ public class RedirectAuthenticationSuccessHandler implements AuthenticationSucce
private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
@Override @Override
public Mono<Void> success(Authentication authentication, public Mono<Void> success(Authentication authentication, WebFilterExchange webFilterExchange) {
ServerWebExchange exchange, WebFilterChain chain) { ServerWebExchange exchange = webFilterExchange.getExchange();
return this.redirectStrategy.sendRedirect(exchange, this.location); return this.redirectStrategy.sendRedirect(exchange, this.location);
} }

View File

@ -19,6 +19,7 @@
package org.springframework.security.web.server.authentication; package org.springframework.security.web.server.authentication;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -29,7 +30,8 @@ import reactor.core.publisher.Mono;
*/ */
public class WebFilterChainAuthenticationSuccessHandler implements AuthenticationSuccessHandler { public class WebFilterChainAuthenticationSuccessHandler implements AuthenticationSuccessHandler {
@Override @Override
public Mono<Void> success(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> success(Authentication authentication, WebFilterExchange webFilterExchange) {
return chain.filter(exchange); ServerWebExchange exchange = webFilterExchange.getExchange();
return webFilterExchange.getChain().filter(exchange);
} }
} }

View File

@ -150,7 +150,7 @@ public class AuthenticationWebFilterTests {
.bindToWebFilters(this.filter) .bindToWebFilters(this.filter)
.build(); .build();
EntityExchangeResult<String> result = client client
.get() .get()
.uri("/") .uri("/")
.exchange() .exchange()
@ -188,7 +188,7 @@ public class AuthenticationWebFilterTests {
Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER")); Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
when(this.authenticationConverter.apply(any())).thenReturn(authentication); when(this.authenticationConverter.apply(any())).thenReturn(authentication);
when(this.authenticationManager.authenticate(any())).thenReturn(authentication); when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
when(this.successHandler.success(any(),any(),any())).thenReturn(Mono.empty()); when(this.successHandler.success(any(),any())).thenReturn(Mono.empty());
when(this.securityContextRepository.save(any(),any())).thenAnswer( a -> Mono.just(a.getArguments()[0])); when(this.securityContextRepository.save(any(),any())).thenAnswer( a -> Mono.just(a.getArguments()[0]));
WebTestClient client = WebTestClientBuilder WebTestClient client = WebTestClientBuilder
@ -202,7 +202,7 @@ public class AuthenticationWebFilterTests {
.expectStatus().isOk() .expectStatus().isOk()
.expectBody().isEmpty(); .expectBody().isEmpty();
verify(this.successHandler).success(eq(authentication.block()), any(), any()); verify(this.successHandler).success(eq(authentication.block()), any());
verify(this.securityContextRepository).save(any(), any()); verify(this.securityContextRepository).save(any(), any());
verifyZeroInteractions(this.entryPoint); verifyZeroInteractions(this.entryPoint);
} }

View File

@ -28,6 +28,7 @@ import org.springframework.security.authentication.AuthenticationCredentialsNotF
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.RedirectStrategy; import org.springframework.security.web.server.RedirectStrategy;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -70,8 +71,8 @@ public class RedirectAuthenticationSuccessHandlerTests {
@Test @Test
public void successWhenNoSubscribersThenNoActions() { public void successWhenNoSubscribersThenNoActions() {
this.handler.success(this.authentication, this.exchange, this.handler.success(this.authentication, new WebFilterExchange(this.exchange,
this.chain); this.chain));
verifyZeroInteractions(this.exchange); verifyZeroInteractions(this.exchange);
} }
@ -80,8 +81,8 @@ public class RedirectAuthenticationSuccessHandlerTests {
public void successWhenSubscribeThenStatusAndLocationSet() { public void successWhenSubscribeThenStatusAndLocationSet() {
this.exchange = MockServerHttpRequest.get("/").toExchange(); this.exchange = MockServerHttpRequest.get("/").toExchange();
this.handler.success(this.authentication, this.exchange, this.handler.success(this.authentication, new WebFilterExchange(this.exchange,
this.chain).block(); this.chain)).block();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo( assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
HttpStatus.FOUND); HttpStatus.FOUND);
@ -95,8 +96,8 @@ public class RedirectAuthenticationSuccessHandlerTests {
this.handler.setRedirectStrategy(this.redirectStrategy); this.handler.setRedirectStrategy(this.redirectStrategy);
this.exchange = MockServerHttpRequest.get("/").toExchange(); this.exchange = MockServerHttpRequest.get("/").toExchange();
assertThat(this.handler.success(this.authentication, this.exchange, assertThat(this.handler.success(this.authentication, new WebFilterExchange(this.exchange,
this.chain)).isEqualTo(result); this.chain))).isEqualTo(result);
verify(this.redirectStrategy).sendRedirect(any(), eq(this.location)); verify(this.redirectStrategy).sendRedirect(any(), eq(this.location));
} }