diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 7d9b00a070..a96da865cc 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -986,7 +986,7 @@ public class ServerHttpSecurity { private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); - private ServerAuthenticationFailureHandler authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); + private ServerAuthenticationFailureHandler authenticationFailureHandler; /** * Configures the {@link ReactiveAuthenticationManager} to use. The default is @@ -1028,6 +1028,7 @@ public class ServerHttpSecurity { /** * The {@link ServerAuthenticationFailureHandler} used after authentication failure. + * Defaults to {@link RedirectServerAuthenticationFailureHandler} redirecting to "/login?error". * * @since 5.2 * @param authenticationFailureHandler the failure handler to use @@ -1175,7 +1176,7 @@ public class ServerHttpSecurity { authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); - authenticationFilter.setAuthenticationFailureHandler(this.authenticationFailureHandler); + authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( @@ -1192,6 +1193,13 @@ public class ServerHttpSecurity { http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); } + private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() { + if (this.authenticationFailureHandler == null) { + this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error"); + } + return this.authenticationFailureHandler; + } + private ServerWebExchangeMatcher createAttemptAuthenticationRequestMatcher() { return new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 046675934b..a783a8212a 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -70,6 +70,7 @@ import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.TestOAuth2Users; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtValidationException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; @@ -518,6 +519,85 @@ public class OAuth2LoginTests { verify(securityContextRepository).save(any(), any()); } + // gh-5562 + @Test + public void oauth2LoginWhenAccessTokenRequestFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, + OAuth2LoginWithCustomBeansConfig.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); + + OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class); + + ServerAuthenticationConverter converter = config.authenticationConverter; + when(converter.convert(any())).thenReturn(Mono.just(authenticationToken)); + + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + OAuth2Error oauth2Error = new OAuth2Error("invalid_request", "Invalid request", null); + when(tokenResponseClient.getTokenResponse(any())).thenThrow(new OAuth2AuthenticationException(oauth2Error)); + + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus() + .is3xxRedirection() + .expectHeader() + .valueEquals("Location", "/login?error"); + } + + // gh-6484 + @Test + public void oauth2LoginWhenIdTokenValidationFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, + OAuth2LoginWithCustomBeansConfig.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class); + + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); + + ServerAuthenticationConverter converter = config.authenticationConverter; + when(converter.convert(any())).thenReturn(Mono.just(authenticationToken)); + + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()) + .additionalParameters(additionalParameters) + .build(); + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + ReactiveJwtDecoderFactory jwtDecoderFactory = config.jwtDecoderFactory; + OAuth2Error oauth2Error = new OAuth2Error("invalid_id_token", "Invalid ID Token", null); + when(jwtDecoderFactory.createDecoder(any())).thenReturn(token -> + Mono.error(new JwtValidationException("ID Token validation failed", Collections.singleton(oauth2Error)))); + + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus() + .is3xxRedirection() + .expectHeader() + .valueEquals("Location", "/login?error"); + } + @Configuration static class OAuth2LoginWithCustomBeansConfig {