From 45bac0fd2ce5a78621fc7dde5912fd3080513c25 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Mon, 11 Sep 2017 22:51:49 -0500 Subject: [PATCH] AuthenticationWebFilter uses AuthenticationFailureHandler Issue gh-4533 --- .../AuthenticationWebFilter.java | 18 +++++++++++------- .../AuthenticationWebFilterTests.java | 19 +++++++++---------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index f5905d46a2..d993f42c1e 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -52,7 +52,7 @@ public class AuthenticationWebFilter implements WebFilter { private Function> authenticationConverter = new HttpBasicAuthenticationConverter(); - private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); + private AuthenticationFailureHandler authenticationFailureHandler = new AuthenticationEntryPointFailureHandler(new HttpBasicAuthenticationEntryPoint()); private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository(); @@ -79,16 +79,18 @@ public class AuthenticationWebFilter implements WebFilter { private Mono authenticate(ServerWebExchange wrappedExchange, WebFilterChain chain, Authentication token) { + WebFilterExchange webFilterExchange = new WebFilterExchange(wrappedExchange, chain); return this.authenticationManager.authenticate(token) - .flatMap(authentication -> onAuthenticationSuccess(authentication, wrappedExchange, chain)) - .onErrorResume(AuthenticationException.class, e -> this.entryPoint.commence(wrappedExchange, e)); + .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) + .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler.onAuthenticationFailure(webFilterExchange, e)); } - private Mono onAuthenticationSuccess(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) { + private Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { + ServerWebExchange exchange = webFilterExchange.getExchange(); SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(authentication); return this.securityContextRepository.save(exchange, securityContext) - .then(this.authenticationSuccessHandler.success(authentication, new WebFilterExchange(exchange, chain))); + .then(this.authenticationSuccessHandler.success(authentication, webFilterExchange)); } public void setSecurityContextRepository( @@ -105,8 +107,10 @@ public class AuthenticationWebFilter implements WebFilter { this.authenticationConverter = authenticationConverter; } - public void setEntryPoint(AuthenticationEntryPoint entryPoint) { - this.entryPoint = entryPoint; + public void setAuthenticationFailureHandler( + AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; } public void setRequiresAuthenticationMatcher( diff --git a/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java b/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java index fc675db61c..08eaeca09a 100644 --- a/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java +++ b/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java @@ -32,7 +32,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; -import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.context.SecurityContextRepository; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.test.web.reactive.server.EntityExchangeResult; @@ -62,7 +61,7 @@ public class AuthenticationWebFilterTests { @Mock private ReactiveAuthenticationManager authenticationManager; @Mock - private AuthenticationEntryPoint entryPoint; + private AuthenticationFailureHandler failureHandler; @Mock private SecurityContextRepository securityContextRepository; @@ -73,8 +72,8 @@ public class AuthenticationWebFilterTests { this.filter = new AuthenticationWebFilter(this.authenticationManager); this.filter.setAuthenticationSuccessHandler(this.successHandler); this.filter.setAuthenticationConverter(this.authenticationConverter); - this.filter.setEntryPoint(this.entryPoint); this.filter.setSecurityContextRepository(this.securityContextRepository); + this.filter.setAuthenticationFailureHandler(this.failureHandler); } @Test @@ -160,7 +159,7 @@ public class AuthenticationWebFilterTests { verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.authenticationManager, this.successHandler, - this.entryPoint); + this.failureHandler); } @Test @@ -180,7 +179,7 @@ public class AuthenticationWebFilterTests { verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.authenticationManager, this.successHandler, - this.entryPoint); + this.failureHandler); } @Test @@ -204,7 +203,7 @@ public class AuthenticationWebFilterTests { verify(this.successHandler).success(eq(authentication.block()), any()); verify(this.securityContextRepository).save(any(), any()); - verifyZeroInteractions(this.entryPoint); + verifyZeroInteractions(this.failureHandler); } @Test @@ -235,7 +234,7 @@ public class AuthenticationWebFilterTests { Mono authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER")); when(this.authenticationConverter.apply(any())).thenReturn(authentication); when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new BadCredentialsException("Failed"))); - when(this.entryPoint.commence(any(),any())).thenReturn(Mono.empty()); + when(this.failureHandler.onAuthenticationFailure(any(),any())).thenReturn(Mono.empty()); WebTestClient client = WebTestClientBuilder .bindToWebFilters(this.filter) @@ -248,7 +247,7 @@ public class AuthenticationWebFilterTests { .expectStatus().isOk() .expectBody().isEmpty(); - verify(this.entryPoint).commence(any(),any()); + verify(this.failureHandler).onAuthenticationFailure(any(),any()); verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.successHandler); } @@ -258,7 +257,7 @@ public class AuthenticationWebFilterTests { Mono authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER")); when(this.authenticationConverter.apply(any())).thenReturn(authentication); when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new RuntimeException("Failed"))); - when(this.entryPoint.commence(any(),any())).thenReturn(Mono.empty()); + when(this.failureHandler.onAuthenticationFailure(any(),any())).thenReturn(Mono.empty()); WebTestClient client = WebTestClientBuilder .bindToWebFilters(this.filter) @@ -272,7 +271,7 @@ public class AuthenticationWebFilterTests { .expectBody().isEmpty(); verify(this.securityContextRepository, never()).save(any(), any()); - verifyZeroInteractions(this.successHandler, this.entryPoint); + verifyZeroInteractions(this.successHandler, this.failureHandler); } @Test(expected = IllegalArgumentException.class)