diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java index afb06ed9f8..6eb2778815 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java @@ -33,6 +33,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -61,6 +62,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint(); + private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> + authenticationEntryPoint.commence(request, response, exception); + /** * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s) * @param authenticationManagerResolver @@ -131,7 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter this.logger.debug("Authentication request for failed: " + failed); } - this.authenticationEntryPoint.commence(request, response, failed); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); } } @@ -152,4 +156,14 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null"); this.authenticationEntryPoint = authenticationEntryPoint; } + + /** + * Set the {@link AuthenticationFailureHandler} to use. Default implementation invokes {@link AuthenticationEntryPoint}. + * @param authenticationFailureHandler the {@code AuthenticationFailureHandler} to use + * @since 5.2 + */ + public final void setAuthenticationFailureHandler(final AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java index 72cac4d8e7..25339f031b 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java @@ -37,6 +37,7 @@ import org.springframework.security.oauth2.server.resource.BearerTokenAuthentica import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -55,6 +56,9 @@ public class BearerTokenAuthenticationFilterTests { @Mock AuthenticationEntryPoint authenticationEntryPoint; + @Mock + AuthenticationFailureHandler authenticationFailureHandler; + @Mock AuthenticationManager authenticationManager; @@ -138,7 +142,7 @@ public class BearerTokenAuthenticationFilterTests { } @Test - public void doFilterWhenAuthenticationFailsThenPropagatesError() throws ServletException, IOException { + public void doFilterWhenAuthenticationFailsWithDefaultHandlerThenPropagatesError() throws ServletException, IOException { BearerTokenError error = new BearerTokenError( BearerTokenErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED, @@ -159,6 +163,29 @@ public class BearerTokenAuthenticationFilterTests { verify(this.authenticationEntryPoint).commence(this.request, this.response, exception); } + @Test + public void doFilterWhenAuthenticationFailsWithCustomHandlerThenPropagatesError() throws ServletException, IOException { + BearerTokenError error = new BearerTokenError( + BearerTokenErrorCodes.INVALID_TOKEN, + HttpStatus.UNAUTHORIZED, + "description", + "uri" + ); + + OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error); + + when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token"); + when(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class))) + .thenThrow(exception); + + BearerTokenAuthenticationFilter filter = + addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager)); + filter.setAuthenticationFailureHandler(this.authenticationFailureHandler); + filter.doFilter(this.request, this.response, this.filterChain); + + verify(this.authenticationFailureHandler).onAuthenticationFailure(this.request, this.response, exception); + } + @Test public void setAuthenticationEntryPointWhenNullThenThrowsException() { BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);