OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
This commit is contained in:
Joe Grandja 2020-06-09 14:40:56 -04:00
parent 11c1236261
commit 674e2c0a8e
4 changed files with 43 additions and 19 deletions

View File

@ -33,6 +33,8 @@ import java.util.function.Supplier;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.util.context.Context; import reactor.util.context.Context;
@ -1089,8 +1091,12 @@ public class ServerHttpSecurity {
private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
if (this.authenticationConverter == null) { if (this.authenticationConverter == null) {
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
ServerAuthenticationConverter authenticationConverter = exchange ->
delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
this.authenticationConverter = authenticationConverter; this.authenticationConverter = authenticationConverter;
} }
return this.authenticationConverter; return this.authenticationConverter;

View File

@ -103,7 +103,10 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt; import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
/** /**
@ -683,7 +686,6 @@ public class OAuth2LoginTests {
} }
} }
@Test @Test
public void logoutWhenUsingOidcLogoutHandlerThenRedirects() { public void logoutWhenUsingOidcLogoutHandlerThenRedirects() {
this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire();
@ -739,6 +741,24 @@ public class OAuth2LoginTests {
} }
} }
// gh-8609
@Test
public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(this.springSecurity)
.build();
webTestClient.get()
.uri("/login/oauth2/code/google")
.exchange()
.expectStatus()
.is3xxRedirection()
.expectHeader()
.valueEquals("Location", "/login?error");
}
static class GitHubWebFilter implements WebFilter { static class GitHubWebFilter implements WebFilter {
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -121,13 +121,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.getAuthorizationExchange().getAuthorizationResponse(); .getAuthorizationExchange().getAuthorizationResponse();
if (authorizationResponse.statusError()) { if (authorizationResponse.statusError()) {
throw new OAuth2AuthenticationException( return Mono.error(new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString()); authorizationResponse.getError(), authorizationResponse.getError().toString()));
} }
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); return Mono.error(new OAuth2AuthenticationException(
oauth2Error, oauth2Error.toString()));
} }
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@ -139,7 +140,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))
.onErrorMap(JwtException.class, e -> { .onErrorMap(JwtException.class, e -> {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null); OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e); return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
}); });
}); });
} }
@ -166,7 +167,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
INVALID_ID_TOKEN_ERROR_CODE, INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null); null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
} }
return createOidcToken(clientRegistration, accessTokenResponse) return createOidcToken(clientRegistration, accessTokenResponse)

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -106,19 +106,16 @@ public class AuthenticationWebFilter implements WebFilter {
.filter( matchResult -> matchResult.isMatch()) .filter( matchResult -> matchResult.isMatch())
.flatMap( matchResult -> this.authenticationConverter.convert(exchange)) .flatMap( matchResult -> this.authenticationConverter.convert(exchange))
.switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap( token -> authenticate(exchange, chain, token)); .flatMap( token -> authenticate(exchange, chain, token))
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
} }
private Mono<Void> authenticate(ServerWebExchange exchange, private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
WebFilterChain chain, Authentication token) {
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
return this.authenticationManagerResolver.resolve(exchange.getRequest()) return this.authenticationManagerResolver.resolve(exchange.getRequest())
.flatMap(authenticationManager -> authenticationManager.authenticate(token)) .flatMap(authenticationManager -> authenticationManager.authenticate(token))
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) .flatMap(authentication -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain)));
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
.onAuthenticationFailure(webFilterExchange, e));
} }
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {