BearerTokenAuthenticationFilter exposes AuthenticationFailureHandler

Make BearerTokenAuthenticationFilter expose an AuthenticationFailureHandler which, by default, invokes the AuthenticationEntryPoint set in the filter.

Fixes gh-7009
This commit is contained in:
Thomas Vitale 2019-06-28 22:03:05 +02:00 committed by Josh Cummings
parent ce79ef2634
commit f9747e6591
2 changed files with 43 additions and 2 deletions

View File

@ -33,6 +33,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
@ -61,6 +62,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint(); private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint();
private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) ->
authenticationEntryPoint.commence(request, response, exception);
/** /**
* Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s) * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s)
* @param authenticationManagerResolver * @param authenticationManagerResolver
@ -131,7 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
this.logger.debug("Authentication request for failed: " + failed); this.logger.debug("Authentication request for failed: " + failed);
} }
this.authenticationEntryPoint.commence(request, response, failed); this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
} }
} }
@ -152,4 +156,14 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null"); Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null");
this.authenticationEntryPoint = authenticationEntryPoint; this.authenticationEntryPoint = authenticationEntryPoint;
} }
/**
* Set the {@link AuthenticationFailureHandler} to use. Default implementation invokes {@link AuthenticationEntryPoint}.
* @param authenticationFailureHandler the {@code AuthenticationFailureHandler} to use
* @since 5.2
*/
public final void setAuthenticationFailureHandler(final AuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
}
} }

View File

@ -37,6 +37,7 @@ import org.springframework.security.oauth2.server.resource.BearerTokenAuthentica
import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenError;
import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatCode;
@ -55,6 +56,9 @@ public class BearerTokenAuthenticationFilterTests {
@Mock @Mock
AuthenticationEntryPoint authenticationEntryPoint; AuthenticationEntryPoint authenticationEntryPoint;
@Mock
AuthenticationFailureHandler authenticationFailureHandler;
@Mock @Mock
AuthenticationManager authenticationManager; AuthenticationManager authenticationManager;
@ -138,7 +142,7 @@ public class BearerTokenAuthenticationFilterTests {
} }
@Test @Test
public void doFilterWhenAuthenticationFailsThenPropagatesError() throws ServletException, IOException { public void doFilterWhenAuthenticationFailsWithDefaultHandlerThenPropagatesError() throws ServletException, IOException {
BearerTokenError error = new BearerTokenError( BearerTokenError error = new BearerTokenError(
BearerTokenErrorCodes.INVALID_TOKEN, BearerTokenErrorCodes.INVALID_TOKEN,
HttpStatus.UNAUTHORIZED, HttpStatus.UNAUTHORIZED,
@ -159,6 +163,29 @@ public class BearerTokenAuthenticationFilterTests {
verify(this.authenticationEntryPoint).commence(this.request, this.response, exception); verify(this.authenticationEntryPoint).commence(this.request, this.response, exception);
} }
@Test
public void doFilterWhenAuthenticationFailsWithCustomHandlerThenPropagatesError() throws ServletException, IOException {
BearerTokenError error = new BearerTokenError(
BearerTokenErrorCodes.INVALID_TOKEN,
HttpStatus.UNAUTHORIZED,
"description",
"uri"
);
OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error);
when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token");
when(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class)))
.thenThrow(exception);
BearerTokenAuthenticationFilter filter =
addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager));
filter.setAuthenticationFailureHandler(this.authenticationFailureHandler);
filter.doFilter(this.request, this.response, this.filterChain);
verify(this.authenticationFailureHandler).onAuthenticationFailure(this.request, this.response, exception);
}
@Test @Test
public void setAuthenticationEntryPointWhenNullThenThrowsException() { public void setAuthenticationEntryPointWhenNullThenThrowsException() {
BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager); BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);