OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
This commit is contained in:
Joe Grandja 2020-06-09 14:40:56 -04:00
parent a372ec9ef5
commit e146a7c16b
4 changed files with 44 additions and 20 deletions

View File

@ -34,6 +34,8 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
@ -578,7 +580,12 @@ public class ServerHttpSecurity {
private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
if (this.authenticationConverter == null) {
this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
ServerAuthenticationConverter authenticationConverter = exchange ->
delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
this.authenticationConverter = authenticationConverter;
}
return this.authenticationConverter;
}

View File

@ -16,12 +16,6 @@
package org.springframework.security.config.web.server;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.junit.Rule;
import org.junit.Test;
import org.openqa.selenium.WebDriver;
@ -67,13 +61,18 @@ import org.springframework.web.reactive.config.EnableWebFlux;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
* @since 5.1
@ -301,6 +300,24 @@ public class OAuth2LoginTests {
}
}
// gh-8609
@Test
public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(this.springSecurity)
.build();
webTestClient.get()
.uri("/login/oauth2/code/google")
.exchange()
.expectStatus()
.is3xxRedirection()
.expectHeader()
.valueEquals("Location", "/login?error");
}
static class GitHubWebFilter implements WebFilter {
@Override

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -117,13 +117,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.getAuthorizationExchange().getAuthorizationResponse();
if (authorizationResponse.statusError()) {
throw new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString());
return Mono.error(new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString()));
}
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
return Mono.error(new OAuth2AuthenticationException(
oauth2Error, oauth2Error.toString()));
}
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@ -156,7 +157,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
}
return createOidcToken(clientRegistration, accessTokenResponse)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -89,17 +89,16 @@ public class AuthenticationWebFilter implements WebFilter {
.filter( matchResult -> matchResult.isMatch())
.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap( token -> authenticate(exchange, chain, token));
.flatMap( token -> authenticate(exchange, chain, token))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
}
private Mono<Void> authenticate(ServerWebExchange exchange,
WebFilterChain chain, Authentication token) {
private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
return this.authenticationManager.authenticate(token)
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(webFilterExchange, e));
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange));
}
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {