Introduce Customizable AuthorizationFailureHandler

Closes gh-13793
This commit is contained in:
greg.lee 2023-11-19 02:28:27 +09:00 committed by Steve Riesenberg
parent bd345fb2a8
commit 07ac0b616b
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
2 changed files with 64 additions and 7 deletions
oauth2/oauth2-client/src
main/java/org/springframework/security/oauth2/client/web
test/java/org/springframework/security/oauth2/client/web

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletResponse;
import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.ThrowableAnalyzer;
@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
private RequestCache requestCache = new HttpSessionRequestCache();
private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;
/**
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
* parameters.
@ -163,6 +167,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
this.requestCache = requestCache;
}
/**
* Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
* the Authorization Server's Authorization Endpoint.
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
* to handle errors redirecting to the Authorization Server's Authorization Endpoint
* @since 6.3
*/
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
@ -174,7 +190,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
}
}
catch (Exception ex) {
this.unsuccessfulRedirectForAuthorization(request, response, ex);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
return;
}
try {
@ -199,7 +216,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
this.sendRedirectForAuthorization(request, response, authorizationRequest);
}
catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
}
return;
}
@ -223,9 +241,10 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
}
private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
Exception ex) throws IOException {
LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
AuthenticationException ex) throws IOException {
Throwable cause = ex.getCause();
LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
// Log an invalid registrationId at WARN level to allow these errors to be
// tuned separately from other errors
this.logger.warn(message, ex);
@ -250,4 +269,12 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
}
private static final class OAuth2AuthorizationRequestException extends AuthenticationException {
OAuth2AuthorizationRequestException(Throwable cause) {
super(cause.getMessage(), cause);
}
}
}

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -119,6 +119,11 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
}
@Test
public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
}
@Test
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path";
@ -144,6 +149,31 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
@Test
public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
Throwable cause = ex.getCause();
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
}
else {
response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
});
this.filter.doFilter(request, response, filterChain);
verifyNoMoreInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
}
@Test
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"