diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java index f49eb4d6df..d6fcad1fea 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -136,6 +136,7 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { .map(ServerWebExchangeMatcher.MatchResult::getVariables) .map(variables -> variables.get(REGISTRATION_ID_URI_VARIABLE_NAME)) .cast(String.class) + .onErrorResume(ClientAuthorizationRequiredException.class, e -> Mono.just(e.getClientRegistrationId())) .flatMap(clientRegistrationId -> this.findByRegistrationId(exchange, clientRegistrationId)) .flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java index 82fc51de2a..ce839c6810 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -133,4 +134,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { }); verify(this.authzRequestRepository).saveAuthorizationRequest(any(), any()); } + + @Test + public void filterWhenExceptionThenRedirected() { + FilteringWebHandler webHandler = new FilteringWebHandler(e -> Mono.error(new ClientAuthorizationRequiredException(this.github.getRegistrationId())), Arrays.asList(this.filter)); + this.client = WebTestClient.bindToWebHandler(webHandler).build(); + FluxExchangeResult result = this.client.get() + .uri("https://example.com/foo").exchange() + .expectStatus() + .is3xxRedirection() + .returnResult(String.class); + } }