OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
This commit is contained in:
Joe Grandja 2020-06-09 14:40:56 -04:00
parent acf56f24a6
commit 38c1e3ffa8
4 changed files with 43 additions and 18 deletions

View File

@ -31,6 +31,8 @@ import java.util.UUID;
import java.util.function.Function;
import java.util.function.Supplier;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
@ -1086,9 +1088,14 @@ public class ServerHttpSecurity {
private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
if (this.authenticationConverter == null) {
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
ServerAuthenticationConverter authenticationConverter = exchange ->
delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
this.authenticationConverter = authenticationConverter;
return authenticationConverter;
}
return this.authenticationConverter;
}

View File

@ -103,7 +103,10 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
/**
@ -683,7 +686,6 @@ public class OAuth2LoginTests {
}
}
@Test
public void logoutWhenUsingOidcLogoutHandlerThenRedirects() {
this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire();
@ -739,6 +741,24 @@ public class OAuth2LoginTests {
}
}
// gh-8609
@Test
public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(this.springSecurity)
.build();
webTestClient.get()
.uri("/login/oauth2/code/google")
.exchange()
.expectStatus()
.is3xxRedirection()
.expectHeader()
.valueEquals("Location", "/login?error");
}
static class GitHubWebFilter implements WebFilter {
@Override

View File

@ -121,13 +121,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.getAuthorizationExchange().getAuthorizationResponse();
if (authorizationResponse.statusError()) {
throw new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString());
return Mono.error(new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString()));
}
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
return Mono.error(new OAuth2AuthenticationException(
oauth2Error, oauth2Error.toString()));
}
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@ -139,7 +140,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))
.onErrorMap(JwtException.class, e -> {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
});
});
}
@ -166,7 +167,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
}
return createOidcToken(clientRegistration, accessTokenResponse)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -105,19 +105,16 @@ public class AuthenticationWebFilter implements WebFilter {
.filter( matchResult -> matchResult.isMatch())
.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap( token -> authenticate(exchange, chain, token));
.flatMap( token -> authenticate(exchange, chain, token))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
}
private Mono<Void> authenticate(ServerWebExchange exchange,
WebFilterChain chain, Authentication token) {
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
return this.authenticationManagerResolver.resolve(exchange)
.flatMap(authenticationManager -> authenticationManager.authenticate(token))
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(webFilterExchange, e));
.flatMap(authentication -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain)));
}
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {