diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index ab8e7c4413..f6844284f9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -192,8 +192,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt if (authorizationRequest == null) { throw authzEx; } - this.sendRedirectForAuthorization(request, response, authorizationRequest); this.requestCache.saveRequest(request, response); + this.sendRedirectForAuthorization(request, response, authorizationRequest); } catch (Exception failed) { this.unsuccessfulRedirectForAuthorization(request, response, failed); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 0df8ea1a03..aaf6f86201 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -48,6 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -333,4 +334,22 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { + "login_hint=user@provider\\.com"); } + // gh-11602 + + @Test + public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted() + throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + willAnswer((invocation) -> assertThat((invocation.getArgument(1)).isCommitted()).isFalse()) + .given(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) + .doFilter(any(ServletRequest.class), any(ServletResponse.class)); + this.filter.doFilter(request, response, filterChain); + assertThat(response.isCommitted()).isTrue(); + } + }