OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
This commit is contained in:
Joe Grandja 2020-06-09 14:40:56 -04:00
parent a372ec9ef5
commit e146a7c16b
4 changed files with 44 additions and 20 deletions

View File

@ -34,6 +34,8 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.util.context.Context; import reactor.util.context.Context;
@ -578,7 +580,12 @@ public class ServerHttpSecurity {
private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
if (this.authenticationConverter == null) { if (this.authenticationConverter == null) {
this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
ServerAuthenticationConverter authenticationConverter = exchange ->
delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
this.authenticationConverter = authenticationConverter;
} }
return this.authenticationConverter; return this.authenticationConverter;
} }

View File

@ -16,12 +16,6 @@
package org.springframework.security.config.web.server; package org.springframework.security.config.web.server;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebDriver;
@ -67,13 +61,18 @@ import org.springframework.web.reactive.config.EnableWebFlux;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler; import org.springframework.web.server.WebHandler;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/** /**
* @author Rob Winch * @author Rob Winch
* @since 5.1 * @since 5.1
@ -301,6 +300,24 @@ public class OAuth2LoginTests {
} }
} }
// gh-8609
@Test
public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(this.springSecurity)
.build();
webTestClient.get()
.uri("/login/oauth2/code/google")
.exchange()
.expectStatus()
.is3xxRedirection()
.expectHeader()
.valueEquals("Location", "/login?error");
}
static class GitHubWebFilter implements WebFilter { static class GitHubWebFilter implements WebFilter {
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -117,13 +117,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.getAuthorizationExchange().getAuthorizationResponse(); .getAuthorizationExchange().getAuthorizationResponse();
if (authorizationResponse.statusError()) { if (authorizationResponse.statusError()) {
throw new OAuth2AuthenticationException( return Mono.error(new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString()); authorizationResponse.getError(), authorizationResponse.getError().toString()));
} }
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); return Mono.error(new OAuth2AuthenticationException(
oauth2Error, oauth2Error.toString()));
} }
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@ -156,7 +157,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
INVALID_ID_TOKEN_ERROR_CODE, INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null); null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
} }
return createOidcToken(clientRegistration, accessTokenResponse) return createOidcToken(clientRegistration, accessTokenResponse)

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -89,17 +89,16 @@ public class AuthenticationWebFilter implements WebFilter {
.filter( matchResult -> matchResult.isMatch()) .filter( matchResult -> matchResult.isMatch())
.flatMap( matchResult -> this.authenticationConverter.convert(exchange)) .flatMap( matchResult -> this.authenticationConverter.convert(exchange))
.switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap( token -> authenticate(exchange, chain, token)); .flatMap( token -> authenticate(exchange, chain, token))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
} }
private Mono<Void> authenticate(ServerWebExchange exchange, private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
WebFilterChain chain, Authentication token) {
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
return this.authenticationManager.authenticate(token) return this.authenticationManager.authenticate(token)
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange));
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(webFilterExchange, e));
} }
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {