diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index afdb3fd1b9..c6ebdbe2dd 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -31,6 +31,8 @@ import java.util.UUID; import java.util.function.Function; import java.util.function.Supplier; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -1086,9 +1088,14 @@ public class ServerHttpSecurity { private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { if (this.authenticationConverter == null) { - ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); - authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate = + new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + ServerAuthenticationConverter authenticationConverter = exchange -> + delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class, + e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())); this.authenticationConverter = authenticationConverter; + return authenticationConverter; } return this.authenticationConverter; } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 4723da1806..b1e5662a3e 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -103,7 +103,10 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** @@ -683,7 +686,6 @@ public class OAuth2LoginTests { } } - @Test public void logoutWhenUsingOidcLogoutHandlerThenRedirects() { this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); @@ -739,6 +741,24 @@ public class OAuth2LoginTests { } } + // gh-8609 + @Test + public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus() + .is3xxRedirection() + .expectHeader() + .valueEquals("Location", "/login?error"); + } + static class GitHubWebFilter implements WebFilter { @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index 9852ff1479..9fa9193b6d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -121,13 +121,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements .getAuthorizationExchange().getAuthorizationResponse(); if (authorizationResponse.statusError()) { - throw new OAuth2AuthenticationException( - authorizationResponse.getError(), authorizationResponse.getError().toString()); + return Mono.error(new OAuth2AuthenticationException( + authorizationResponse.getError(), authorizationResponse.getError().toString())); } if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + return Mono.error(new OAuth2AuthenticationException( + oauth2Error, oauth2Error.toString())); } OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( @@ -139,7 +140,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) .onErrorMap(JwtException.class, e -> { OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null); - throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e); + return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e); }); }); } @@ -166,7 +167,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements INVALID_ID_TOKEN_ERROR_CODE, "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), null); - throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); + return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString())); } return createOidcToken(clientRegistration, accessTokenResponse) diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index 1693381294..c2d7476ba3 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,19 +105,16 @@ public class AuthenticationWebFilter implements WebFilter { .filter( matchResult -> matchResult.isMatch()) .flatMap( matchResult -> this.authenticationConverter.convert(exchange)) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .flatMap( token -> authenticate(exchange, chain, token)); + .flatMap( token -> authenticate(exchange, chain, token)) + .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler + .onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); } - private Mono authenticate(ServerWebExchange exchange, - WebFilterChain chain, Authentication token) { - WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); - + private Mono authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { return this.authenticationManagerResolver.resolve(exchange) .flatMap(authenticationManager -> authenticationManager.authenticate(token)) .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) - .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) - .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler - .onAuthenticationFailure(webFilterExchange, e)); + .flatMap(authentication -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain))); } protected Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {