diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 386315fa0f..5341ccae37 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -201,8 +201,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt if (authorizationRequest == null) { throw authzEx; } - this.sendRedirectForAuthorization(request, response, authorizationRequest); this.requestCache.saveRequest(request, response); + this.sendRedirectForAuthorization(request, response, authorizationRequest); } catch (Exception failed) { this.unsuccessfulRedirectForAuthorization(request, response, failed); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 751344f339..a8b92dd67b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -52,6 +52,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -369,4 +370,22 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); } + // gh-11602 + + @Test + public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted() + throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + willAnswer((invocation) -> assertThat((invocation.getArgument(1)).isCommitted()).isFalse()) + .given(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) + .doFilter(any(ServletRequest.class), any(ServletResponse.class)); + this.filter.doFilter(request, response, filterChain); + assertThat(response.isCommitted()).isTrue(); + } + }