diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index fc5fb598b3..4ad62b09ce 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -20,7 +20,11 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.util.Assert; /** @@ -40,6 +44,7 @@ import org.springframework.util.Assert; * @see Section 4.1.4 Access Token Response */ public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; private final OAuth2AccessTokenResponseClient accessTokenResponseClient; /** @@ -59,8 +64,18 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - OAuth2AuthorizationExchangeValidator.validate( - authorizationCodeAuthentication.getAuthorizationExchange()); + OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication + .getAuthorizationExchange().getAuthorizationResponse(); + if (authorizationResponse.statusError()) { + throw new OAuth2AuthorizationException(authorizationResponse.getError()); + } + + OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication + .getAuthorizationExchange().getAuthorizationRequest(); + if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); + throw new OAuth2AuthorizationException(oauth2Error); + } OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java index 4ecdd3b85c..28e4d7bdfe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,13 @@ import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessT import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; import reactor.core.publisher.Mono; @@ -55,8 +59,8 @@ import java.util.function.Function; * @see Section 4.1.3 Access Token Request * @see Section 4.1.4 Access Token Response */ -public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements - ReactiveAuthenticationManager { +public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; private final ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; public OAuth2AuthorizationCodeReactiveAuthenticationManager( @@ -70,7 +74,16 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - OAuth2AuthorizationExchangeValidator.validate(token.getAuthorizationExchange()); + OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange().getAuthorizationResponse(); + if (authorizationResponse.statusError()) { + return Mono.error(new OAuth2AuthorizationException(authorizationResponse.getError())); + } + + OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange().getAuthorizationRequest(); + if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); + return Mono.error(new OAuth2AuthorizationException(oauth2Error)); + } OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( token.getClientRegistration(), diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java deleted file mode 100644 index a240e0521a..0000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.authentication; - -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; - -/** - * A validator for an "exchange" of an OAuth 2.0 Authorization Request and Response. - * - * @author Joe Grandja - * @since 5.1 - * @see OAuth2AuthorizationExchange - */ -final class OAuth2AuthorizationExchangeValidator { - private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; - - static void validate(OAuth2AuthorizationExchange authorizationExchange) { - OAuth2AuthorizationRequest authorizationRequest = authorizationExchange.getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); - - if (authorizationResponse.statusError()) { - throw new OAuth2AuthorizationException(authorizationResponse.getError()); - } - - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); - throw new OAuth2AuthorizationException(oauth2Error); - } - } -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index ecd9084744..bcae130597 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -27,6 +27,8 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -146,15 +148,21 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return this.requiresAuthenticationMatcher.matches(exchange) .filter(ServerWebExchangeMatcher.MatchResult::isMatch) - .flatMap(matchResult -> this.authenticationConverter.convert(exchange)) + .flatMap(matchResult -> + this.authenticationConverter.convert(exchange) + .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException( + e.getError(), e.getError().toString()))) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .flatMap(token -> authenticate(exchange, chain, token)); + .flatMap(token -> authenticate(exchange, chain, token)) + .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler + .onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); } - private Mono authenticate(ServerWebExchange exchange, - WebFilterChain chain, Authentication token) { + private Mono authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); return this.authenticationManager.authenticate(token) + .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException( + e.getError(), e.getError().toString())) .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java index 3d76bc390e..31c9b57cf5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java @@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web.server; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; -import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -33,7 +32,7 @@ import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; /** - * Converts from a {@link ServerWebExchange} to an {@link OAuth2LoginAuthenticationToken} that can be authenticated. The + * Converts from a {@link ServerWebExchange} to an {@link OAuth2AuthorizationCodeAuthenticationToken} that can be authenticated. The * converter does not validate any errors it only performs a conversion. * @author Rob Winch * @since 5.1 diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index cbc08accd6..d635688cf4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -29,6 +29,9 @@ import org.springframework.security.oauth2.client.authentication.TestOAuth2Autho import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.CollectionUtils; @@ -41,6 +44,7 @@ import java.util.LinkedHashMap; import java.util.Map; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -226,6 +230,56 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { verifyZeroInteractions(this.authenticationManager); } + // gh-8609 + @Test + public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); + + MockServerHttpRequest authorizationRequest = + createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete(), Collections.emptyList()); + + this.authorizationRequestRepository.saveAuthorizationRequest(oauth2AuthorizationRequest, exchange).block(); + + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("client_registration_not_found"); + verifyZeroInteractions(this.authenticationManager); + } + + // gh-8609 + @Test + public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())) + .thenReturn(Mono.just(clientRegistration)); + + MockServerHttpRequest authorizationRequest = + createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); + + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete(), Collections.emptyList()); + + this.authorizationRequestRepository.saveAuthorizationRequest(oauth2AuthorizationRequest, exchange).block(); + + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("authorization_error"); + } + private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( MockServerHttpRequest authorizationRequest, ClientRegistration registration) { Map additionalParameters = new HashMap<>();