diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 95db5745df..1ab1c0a69a 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -34,6 +34,8 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -578,7 +580,12 @@ public class ServerHttpSecurity { private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { if (this.authenticationConverter == null) { - this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate = + new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + ServerAuthenticationConverter authenticationConverter = exchange -> + delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class, + e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())); + this.authenticationConverter = authenticationConverter; } return this.authenticationConverter; } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 090fa7ac2e..db241653a6 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -16,12 +16,6 @@ package org.springframework.security.config.web.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import org.junit.Rule; import org.junit.Test; import org.openqa.selenium.WebDriver; @@ -67,13 +61,18 @@ import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; - import org.springframework.web.server.WebHandler; import reactor.core.publisher.Mono; import java.time.Duration; import java.time.Instant; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + /** * @author Rob Winch * @since 5.1 @@ -301,6 +300,24 @@ public class OAuth2LoginTests { } } + // gh-8609 + @Test + public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus() + .is3xxRedirection() + .expectHeader() + .valueEquals("Location", "/login?error"); + } + static class GitHubWebFilter implements WebFilter { @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index 8d0c15bb95..1a3cdbabb1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -117,13 +117,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements .getAuthorizationExchange().getAuthorizationResponse(); if (authorizationResponse.statusError()) { - throw new OAuth2AuthenticationException( - authorizationResponse.getError(), authorizationResponse.getError().toString()); + return Mono.error(new OAuth2AuthenticationException( + authorizationResponse.getError(), authorizationResponse.getError().toString())); } if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + return Mono.error(new OAuth2AuthenticationException( + oauth2Error, oauth2Error.toString())); } OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( @@ -156,7 +157,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements INVALID_ID_TOKEN_ERROR_CODE, "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), null); - throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); + return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString())); } return createOidcToken(clientRegistration, accessTokenResponse) diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index 97fbc631cb..07a0c1a85a 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,17 +89,16 @@ public class AuthenticationWebFilter implements WebFilter { .filter( matchResult -> matchResult.isMatch()) .flatMap( matchResult -> this.authenticationConverter.convert(exchange)) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .flatMap( token -> authenticate(exchange, chain, token)); + .flatMap( token -> authenticate(exchange, chain, token)) + .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler + .onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); } - private Mono authenticate(ServerWebExchange exchange, - WebFilterChain chain, Authentication token) { + private Mono authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); return this.authenticationManager.authenticate(token) .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) - .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) - .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler - .onAuthenticationFailure(webFilterExchange, e)); + .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)); } protected Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {