mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-02-13 09:54:57 +00:00
Introduce Customizable AuthorizationFailureHandler
Closes gh-13793
This commit is contained in:
parent
bd345fb2a8
commit
07ac0b616b
oauth2/oauth2-client/src
main/java/org/springframework/security/oauth2/client/web
test/java/org/springframework/security/oauth2/client/web
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.core.log.LogMessage;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||
@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.web.DefaultRedirectStrategy;
|
||||
import org.springframework.security.web.RedirectStrategy;
|
||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
|
||||
import org.springframework.security.web.savedrequest.RequestCache;
|
||||
import org.springframework.security.web.util.ThrowableAnalyzer;
|
||||
@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
|
||||
private RequestCache requestCache = new HttpSessionRequestCache();
|
||||
|
||||
private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
|
||||
* parameters.
|
||||
@ -163,6 +167,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
this.requestCache = requestCache;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
|
||||
* the Authorization Server's Authorization Endpoint.
|
||||
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
|
||||
* to handle errors redirecting to the Authorization Server's Authorization Endpoint
|
||||
* @since 6.3
|
||||
*/
|
||||
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
|
||||
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
|
||||
this.authenticationFailureHandler = authenticationFailureHandler;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
throws ServletException, IOException {
|
||||
@ -174,7 +190,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
}
|
||||
}
|
||||
catch (Exception ex) {
|
||||
this.unsuccessfulRedirectForAuthorization(request, response, ex);
|
||||
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
|
||||
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
@ -199,7 +216,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
this.sendRedirectForAuthorization(request, response, authorizationRequest);
|
||||
}
|
||||
catch (Exception failed) {
|
||||
this.unsuccessfulRedirectForAuthorization(request, response, failed);
|
||||
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
|
||||
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -223,9 +241,10 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
}
|
||||
|
||||
private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
|
||||
Exception ex) throws IOException {
|
||||
LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
|
||||
if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
|
||||
AuthenticationException ex) throws IOException {
|
||||
Throwable cause = ex.getCause();
|
||||
LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
|
||||
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
|
||||
// Log an invalid registrationId at WARN level to allow these errors to be
|
||||
// tuned separately from other errors
|
||||
this.logger.warn(message, ex);
|
||||
@ -250,4 +269,12 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||
|
||||
}
|
||||
|
||||
private static final class OAuth2AuthorizationRequestException extends AuthenticationException {
|
||||
|
||||
OAuth2AuthorizationRequestException(Throwable cause) {
|
||||
super(cause.getMessage(), cause);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -119,6 +119,11 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
|
||||
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
|
||||
String requestUri = "/path";
|
||||
@ -144,6 +149,31 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
|
||||
throws Exception {
|
||||
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
|
||||
+ this.registration1.getRegistrationId() + "-invalid";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
|
||||
Throwable cause = ex.getCause();
|
||||
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
|
||||
response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
|
||||
}
|
||||
else {
|
||||
response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
|
||||
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
|
||||
}
|
||||
});
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
verifyNoMoreInteractions(filterChain);
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
|
||||
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
|
||||
|
Loading…
x
Reference in New Issue
Block a user