Add OAuth2Authorization success/failure handlers

Fixes gh-7840
This commit is contained in:
Joe Grandja 2020-02-13 05:36:17 -05:00
parent 1b68cdb650
commit 69156b741d
15 changed files with 1349 additions and 101 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,11 @@ import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -31,20 +36,50 @@ import java.util.function.Function;
/**
* An implementation of an {@link OAuth2AuthorizedClientManager}
* that is capable of operating outside of a {@code HttpServletRequest} context,
* that is capable of operating outside of the context of a {@code HttpServletRequest},
* e.g. in a scheduled/background thread and/or in the service-tier.
*
* <p>
* (When operating <em>within</em> the context of a {@code HttpServletRequest},
* use {@link DefaultOAuth2AuthorizedClientManager} instead.)
*
* <h2>Authorized Client Persistence</h2>
*
* <p>
* This manager utilizes an {@link OAuth2AuthorizedClientService}
* to persist {@link OAuth2AuthorizedClient}s.
*
* <p>
* By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
* will be saved in the {@link OAuth2AuthorizedClientService}.
* This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler}
* via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}.
*
* <p>
* By default, when an authorization attempt fails due to an
* {@value OAuth2ErrorCodes#INVALID_GRANT} error,
* the previously saved {@link OAuth2AuthorizedClient}
* will be removed from the {@link OAuth2AuthorizedClientService}.
* (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur
* when a refresh token that is no longer valid is used to retrieve a new access token.)
* This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler}
* via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2AuthorizedClientManager
* @see OAuth2AuthorizedClientProvider
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizationSuccessHandler
* @see OAuth2AuthorizationFailureHandler
*/
public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper;
private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
/**
* Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters.
@ -58,6 +93,9 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientService = authorizedClientService;
this.contextAttributesMapper = new DefaultContextAttributesMapper();
this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientService);
this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientService);
}
@Nullable
@ -92,9 +130,16 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
})
.build();
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
try {
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
} catch (OAuth2AuthorizationException ex) {
this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap());
throw ex;
}
if (authorizedClient != null) {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
this.authorizationSuccessHandler.onAuthorizationSuccess(
authorizedClient, principal, Collections.emptyMap());
} else {
// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
// For these cases, return the provided `authorizationContext.authorizedClient`.
@ -128,6 +173,36 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
this.contextAttributesMapper = contextAttributesMapper;
}
/**
* Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations.
*
* <p>
* A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default.
*
* @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations
* @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler
* @since 5.3
*/
public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
this.authorizationSuccessHandler = authorizationSuccessHandler;
}
/**
* Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures.
*
* <p>
* A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default.
*
* @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures
* @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler
* @since 5.3
*/
public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
this.authorizationFailureHandler = authorizationFailureHandler;
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/

View File

@ -0,0 +1,48 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import java.util.Map;
/**
* Handles when an OAuth 2.0 Client fails to authorize (or re-authorize)
* via the Authorization Server or Resource Server.
*
* @author Joe Grandja
* @since 5.3
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientManager
*/
@FunctionalInterface
public interface OAuth2AuthorizationFailureHandler {
/**
* Called when an OAuth 2.0 Client fails to authorize (or re-authorize)
* via the Authorization Server or Resource Server.
*
* @param authorizationException the exception that contains details about what failed
* @param principal the {@code Principal} associated with the attempted authorization
* @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions.
* For example, this might contain a {@code javax.servlet.http.HttpServletRequest}
* and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed
* within the context of a {@code javax.servlet.ServletContext}.
*/
void onAuthorizationFailure(OAuth2AuthorizationException authorizationException,
Authentication principal, Map<String, Object> attributes);
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.springframework.security.core.Authentication;
import java.util.Map;
/**
* Handles when an OAuth 2.0 Client has been successfully
* authorized (or re-authorized) via the Authorization Server.
*
* @author Joe Grandja
* @since 5.3
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientManager
*/
@FunctionalInterface
public interface OAuth2AuthorizationSuccessHandler {
/**
* Called when an OAuth 2.0 Client has been successfully
* authorized (or re-authorized) via the Authorization Server.
*
* @param authorizedClient the client that was successfully authorized (or re-authorized)
* @param principal the {@code Principal} associated with the authorized client
* @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions.
* For example, this might contain a {@code javax.servlet.http.HttpServletRequest}
* and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed
* within the context of a {@code javax.servlet.ServletContext}.
*/
void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient,
Authentication principal, Map<String, Object> attributes);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@ -30,6 +30,7 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
@ -74,9 +75,22 @@ public final class DefaultAuthorizationCodeTokenResponseClient implements OAuth2
try {
response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
throw new OAuth2AuthorizationException(oauth2Error, ex);
int statusCode = 500;
if (ex instanceof RestClientResponseException) {
statusCode = ((RestClientResponseException) ex).getRawStatusCode();
}
OAuth2Error oauth2Error = new OAuth2Error(
INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
null);
String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
statusCode,
oauth2Error);
throw new ClientAuthorizationException(
oauth2Error,
authorizationCodeGrantRequest.getClientRegistration().getRegistrationId(),
message,
ex);
}
OAuth2AccessTokenResponse tokenResponse = response.getBody();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@ -30,6 +30,7 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
@ -74,9 +75,22 @@ public final class DefaultClientCredentialsTokenResponseClient implements OAuth2
try {
response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
throw new OAuth2AuthorizationException(oauth2Error, ex);
int statusCode = 500;
if (ex instanceof RestClientResponseException) {
statusCode = ((RestClientResponseException) ex).getRawStatusCode();
}
OAuth2Error oauth2Error = new OAuth2Error(
INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
null);
String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
statusCode,
oauth2Error);
throw new ClientAuthorizationException(
oauth2Error,
clientCredentialsGrantRequest.getClientRegistration().getRegistrationId(),
message,
ex);
}
OAuth2AccessTokenResponse tokenResponse = response.getBody();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@ -30,6 +30,7 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
@ -74,9 +75,22 @@ public final class DefaultPasswordTokenResponseClient implements OAuth2AccessTok
try {
response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
throw new OAuth2AuthorizationException(oauth2Error, ex);
int statusCode = 500;
if (ex instanceof RestClientResponseException) {
statusCode = ((RestClientResponseException) ex).getRawStatusCode();
}
OAuth2Error oauth2Error = new OAuth2Error(
INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
null);
String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
statusCode,
oauth2Error);
throw new ClientAuthorizationException(
oauth2Error,
passwordGrantRequest.getClientRegistration().getRegistrationId(),
message,
ex);
}
OAuth2AccessTokenResponse tokenResponse = response.getBody();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@ -30,6 +30,7 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
@ -73,9 +74,22 @@ public final class DefaultRefreshTokenTokenResponseClient implements OAuth2Acces
try {
response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
throw new OAuth2AuthorizationException(oauth2Error, ex);
int statusCode = 500;
if (ex instanceof RestClientResponseException) {
statusCode = ((RestClientResponseException) ex).getRawStatusCode();
}
OAuth2Error oauth2Error = new OAuth2Error(
INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
null);
String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
statusCode,
oauth2Error);
throw new ClientAuthorizationException(
oauth2Error,
refreshTokenGrantRequest.getClientRegistration().getRegistrationId(),
message,
ex);
}
OAuth2AccessTokenResponse tokenResponse = response.getBody();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,10 +31,10 @@ import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.id.ClientID;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -100,9 +100,19 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
httpRequest.setReadTimeout(30000);
tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send());
} catch (ParseException | IOException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
throw new OAuth2AuthorizationException(oauth2Error, ex);
int statusCode = 500;
OAuth2Error oauth2Error = new OAuth2Error(
INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
null);
String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
statusCode,
oauth2Error);
throw new ClientAuthorizationException(
oauth2Error,
clientRegistration.getRegistrationId(),
message,
ex);
}
if (!tokenResponse.indicatesSuccess()) {
@ -117,7 +127,7 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
errorObject.getDescription(),
errorObject.getURI() != null ? errorObject.getURI().toString() : null);
}
throw new OAuth2AuthorizationException(oauth2Error);
throw new ClientAuthorizationException(oauth2Error, clientRegistration.getRegistrationId());
}
AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse;

View File

@ -15,22 +15,20 @@
*/
package org.springframework.security.oauth2.client.web;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -39,19 +37,57 @@ import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
/**
* The default implementation of an {@link OAuth2AuthorizedClientManager}.
* The default implementation of an {@link OAuth2AuthorizedClientManager}
* for use within the context of a {@code HttpServletRequest}.
*
* <p>
* (When operating <em>outside</em> of the context of a {@code HttpServletRequest},
* use {@link AuthorizedClientServiceOAuth2AuthorizedClientManager} instead.)
*
* <h2>Authorized Client Persistence</h2>
*
* <p>
* This manager utilizes an {@link OAuth2AuthorizedClientRepository}
* to persist {@link OAuth2AuthorizedClient}s.
*
* <p>
* By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
* will be saved in the {@link OAuth2AuthorizedClientRepository}.
* This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler}
* via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}.
*
* <p>
* By default, when an authorization attempt fails due to an
* {@value OAuth2ErrorCodes#INVALID_GRANT} error,
* the previously saved {@link OAuth2AuthorizedClient}
* will be removed from the {@link OAuth2AuthorizedClientRepository}.
* (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur
* when a refresh token that is no longer valid is used to retrieve a new access token.)
* This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler}
* via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
*
* @author Joe Grandja
* @since 5.2
* @see OAuth2AuthorizedClientManager
* @see OAuth2AuthorizedClientProvider
* @see OAuth2AuthorizationSuccessHandler
* @see OAuth2AuthorizationFailureHandler
*/
public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper;
private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
/**
* Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters.
@ -65,6 +101,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
this.contextAttributesMapper = new DefaultContextAttributesMapper();
this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientRepository);
this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository);
}
@Nullable
@ -105,9 +144,17 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
})
.build();
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
try {
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
} catch (OAuth2AuthorizationException ex) {
this.authorizationFailureHandler.onAuthorizationFailure(
ex, principal, createAttributes(servletRequest, servletResponse));
throw ex;
}
if (authorizedClient != null) {
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse);
this.authorizationSuccessHandler.onAuthorizationSuccess(
authorizedClient, principal, createAttributes(servletRequest, servletResponse));
} else {
// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
// For these cases, return the provided `authorizationContext.authorizedClient`.
@ -119,12 +166,19 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
return authorizedClient;
}
private static Map<String, Object> createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
Map<String, Object> attributes = new HashMap<>();
attributes.put(HttpServletRequest.class.getName(), servletRequest);
attributes.put(HttpServletResponse.class.getName(), servletResponse);
return attributes;
}
private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
if (servletRequest == null) {
RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context instanceof ServletRequestAttributes) {
servletRequest = ((ServletRequestAttributes) context).getRequest();
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) {
servletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
}
}
return servletRequest;
@ -133,9 +187,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
if (servletResponse == null) {
RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context instanceof ServletRequestAttributes) {
servletResponse = ((ServletRequestAttributes) context).getResponse();
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) {
servletResponse = ((ServletRequestAttributes) requestAttributes).getResponse();
}
}
return servletResponse;
@ -163,6 +217,36 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
this.contextAttributesMapper = contextAttributesMapper;
}
/**
* Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations.
*
* <p>
* A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default.
*
* @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations
* @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler
* @since 5.3
*/
public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
this.authorizationSuccessHandler = authorizationSuccessHandler;
}
/**
* Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures.
*
* <p>
* A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default.
*
* @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures
* @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler
* @since 5.3
*/
public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
this.authorizationFailureHandler = authorizationFailureHandler;
}
/**
* The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
*/

View File

@ -0,0 +1,169 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* An {@link OAuth2AuthorizationFailureHandler} that removes an {@link OAuth2AuthorizedClient}
* from an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}
* for a specific set of OAuth 2.0 error codes.
*
* @author Joe Grandja
* @since 5.3
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientRepository
* @see OAuth2AuthorizedClientService
*/
public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements OAuth2AuthorizationFailureHandler {
/**
* The default OAuth 2.0 error codes that will trigger removal of an {@link OAuth2AuthorizedClient}.
* @see OAuth2ErrorCodes
*/
public static final Set<String> DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
/*
* Returned from Resource Servers when an access token provided is expired, revoked,
* malformed, or invalid for other reasons.
*
* Note that this is needed because ServletOAuth2AuthorizedClientExchangeFilterFunction
* delegates this type of failure received from a Resource Server
* to this failure handler.
*/
OAuth2ErrorCodes.INVALID_TOKEN,
/*
* Returned from Authorization Servers when the authorization grant or refresh token is invalid, expired, revoked,
* does not match the redirection URI used in the authorization request, or was issued to another client.
*/
OAuth2ErrorCodes.INVALID_GRANT
)));
/**
* The OAuth 2.0 error codes which will trigger removal of an {@link OAuth2AuthorizedClient}.
* @see OAuth2ErrorCodes
*/
private final Set<String> removeAuthorizedClientErrorCodes;
/**
* A delegate that removes an {@link OAuth2AuthorizedClient} from a
* {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}
* if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
*/
private final OAuth2AuthorizedClientRemover delegate;
@FunctionalInterface
private interface OAuth2AuthorizedClientRemover {
void removeAuthorizedClient(String clientRegistrationId, Authentication principal, Map<String, Object> attributes);
}
/**
* Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
*
* @param authorizedClientRepository the repository from which authorized clients will be removed
* if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
*/
public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientRepository authorizedClientRepository) {
this(authorizedClientRepository, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
}
/**
* Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
*
* @param authorizedClientRepository the repository from which authorized clients will be removed
* if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
* @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client.
* @see OAuth2ErrorCodes
*/
public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(
OAuth2AuthorizedClientRepository authorizedClientRepository,
Set<String> removeAuthorizedClientErrorCodes) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
this.delegate = (clientRegistrationId, principal, attributes) ->
authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal,
(HttpServletRequest) attributes.get(HttpServletRequest.class.getName()),
(HttpServletResponse) attributes.get(HttpServletResponse.class.getName()));
}
/**
* Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
*
* @param authorizedClientService the service from which authorized clients will be removed
* if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
*/
public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientService authorizedClientService) {
this(authorizedClientService, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
}
/**
* Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
*
* @param authorizedClientService the service from which authorized clients will be removed
* if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
* @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client.
* @see OAuth2ErrorCodes
*/
public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(
OAuth2AuthorizedClientService authorizedClientService,
Set<String> removeAuthorizedClientErrorCodes) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
this.delegate = (clientRegistrationId, principal, attributes) ->
authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName());
}
@Override
public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException,
Authentication principal, Map<String, Object> attributes) {
if (authorizationException instanceof ClientAuthorizationException &&
hasRemovalErrorCode(authorizationException)) {
ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException;
this.delegate.removeAuthorizedClient(
clientAuthorizationException.getClientRegistrationId(), principal, attributes);
}
}
/**
* Returns true if the given exception has an error code that
* indicates that the authorized client should be removed.
*
* @param authorizationException the exception that caused the authorization failure
* @return true if the given exception has an error code that
* indicates that the authorized client should be removed.
*/
private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) {
return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode());
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;
/**
* An {@link OAuth2AuthorizationSuccessHandler} that saves an {@link OAuth2AuthorizedClient}
* in an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}.
*
* @author Joe Grandja
* @since 5.3
* @see OAuth2AuthorizedClient
* @see OAuth2AuthorizedClientRepository
* @see OAuth2AuthorizedClientService
*/
public class SaveAuthorizedClientOAuth2AuthorizationSuccessHandler implements OAuth2AuthorizationSuccessHandler {
/**
* A delegate that saves an {@link OAuth2AuthorizedClient} in an
* {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}.
*/
private final OAuth2AuthorizationSuccessHandler delegate;
/**
* Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters.
*
* @param authorizedClientRepository The repository in which authorized clients will be saved.
*/
public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(
OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.delegate = (authorizedClient, principal, attributes) ->
authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal,
(HttpServletRequest) attributes.get(HttpServletRequest.class.getName()),
(HttpServletResponse) attributes.get(HttpServletResponse.class.getName()));
}
/**
* Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters.
*
* @param authorizedClientService The service in which authorized clients will be saved.
*/
public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(
OAuth2AuthorizedClientService authorizedClientService) {
Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
this.delegate = (authorizedClient, principal, attributes) ->
authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
}
@Override
public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient,
Authentication principal, Map<String, Object> attributes) {
this.delegate.onAuthorizationSuccess(authorizedClient, principal, attributes);
}
}

View File

@ -13,15 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
@ -35,7 +38,13 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
@ -44,6 +53,7 @@ import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
@ -52,18 +62,25 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
* token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is
* intended to be used in a servlet environment.
* Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth 2.0 requests
* by including the {@link OAuth2AuthorizedClient#getAccessToken() access token} as a bearer token.
*
* <p>
* <b>NOTE:</b>This class is intended to be used in a {@code Servlet} environment.
*
* <p>
* Example usage:
*
* <pre>
* ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository);
* ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
* WebClient webClient = WebClient.builder()
* .apply(oauth2.oauth2Configuration())
* .build();
@ -76,23 +93,35 @@ import java.util.function.Consumer;
* .bodyToMono(String.class);
* </pre>
*
* An attempt to automatically refresh the token will be made if all of the following
* are true:
* <h3>Authentication and Authorization Failures</h3>
*
* <ul>
* <li>The {@link OAuth2AuthorizedClientManager} is not null</li>
* <li>A refresh token is present on the {@link OAuth2AuthorizedClient}</li>
* <li>The access token is expired</li>
* <li>The {@link SecurityContextHolder} will be used to attempt to save
* the token. If it is empty, then the principal name on the {@link OAuth2AuthorizedClient}
* will be used to create an Authentication for saving.</li>
* </ul>
* <p>
* Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized)
* and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server
* to a {@link OAuth2AuthorizationFailureHandler}.
* A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} can be used
* to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result
* in a new token being retrieved from an Authorization Server, and sent to the Resource Server.
*
* <p>
* If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)}
* constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
* will be configured automatically.
*
* <p>
* If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)}
* constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
* will <em>NOT</em> be configured automatically.
* It is recommended that you configure one via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
*
* @author Rob Winch
* @author Joe Grandja
* @author Roman Matiushchenko
* @since 5.1
* @see OAuth2AuthorizedClientManager
* @see DefaultOAuth2AuthorizedClientManager
* @see OAuth2AuthorizedClientProvider
* @see OAuth2AuthorizedClientProviderBuilder
*/
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@ -103,6 +132,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
*/
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
@ -125,35 +155,75 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private String defaultClientRegistrationId;
private ClientResponseHandler clientResponseHandler;
@FunctionalInterface
private interface ClientResponseHandler {
Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> response);
}
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
}
/**
* Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
*
* <p>
* When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403)
* failures returned from an OAuth 2.0 Resource Server will <em>NOT</em> be forwarded to an
* {@link OAuth2AuthorizationFailureHandler}.
* Therefore, future requests to the Resource Server will most likely use the same (likely invalid) token,
* resulting in the same errors returned from the Resource Server.
* It is recommended to configure a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
* via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}
* so that authentication and authorization failures returned from a Resource Server
* will result in removing the authorized client, so that a new token is retrieved for future requests.
*
* @since 5.2
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s)
*/
public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
this.authorizedClientManager = authorizedClientManager;
this.clientResponseHandler = (request, responseMono) -> responseMono;
}
/**
* Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
*
* <p>
* Since 5.3, when this constructor is used, authentication (HTTP 401)
* and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server
* will be forwarded to a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler},
* which will potentially remove the {@link OAuth2AuthorizedClient} from the given
* {@link OAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned.
* Authentication failures returned from an OAuth 2.0 Resource Server typically indicate
* that the token is invalid, and should not be used in future requests.
* Removing the authorized client from the repository will ensure that the existing
* token will not be sent for future requests to the Resource Server,
* and a new token is retrieved from the Authorization Server and used for
* future requests to the Resource Server.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizedClientRepository the repository of authorized clients
*/
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
OAuth2AuthorizationFailureHandler authorizationFailureHandler =
new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository);
this.authorizedClientManager = createDefaultAuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository, authorizationFailureHandler);
this.defaultAuthorizedClientManager = true;
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager(
ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) {
ClientRegistrationRepository clientRegistrationRepository,
OAuth2AuthorizedClientRepository authorizedClientRepository,
OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
OAuth2AuthorizedClientProvider authorizedClientProvider =
OAuth2AuthorizedClientProviderBuilder.builder()
@ -165,6 +235,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
clientRegistrationRepository, authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler);
return authorizedClientManager;
}
@ -333,19 +404,47 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
updateDefaultAuthorizedClientManager();
}
/**
* Sets the {@link OAuth2AuthorizationFailureHandler} that handles
* authentication and authorization failures when communicating
* to the OAuth 2.0 Resource Server.
*
* <p>
* For example, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
* is typically used to remove the cached {@link OAuth2AuthorizedClient},
* so that the same token is no longer used in future requests to the Resource Server.
*
* <p>
* The failure handler used by default depends on which constructor was used
* to construct this {@link ServletOAuth2AuthorizedClientExchangeFilterFunction}.
* See the constructors for more details.
*
* @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authentication and authorization failures
* @since 5.3
*/
public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return mergeRequestAttributesIfNecessary(request)
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req))
.flatMap(req -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req))
.switchIfEmpty(Mono.defer(() ->
mergeRequestAttributesIfNecessary(request)
.filter(req -> resolveClientRegistrationId(req) != null)
.flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req))
))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
.flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next))
.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
}
private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
return next.exchange(request)
.transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono));
}
private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
@ -443,13 +542,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
});
OAuth2AuthorizeRequest authorizeRequest = builder.build();
// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
// NOTE:
// 'authorizedClientManager.authorize()' needs to be executed
// on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
// since it performs a blocking I/O operation using RestTemplate internally
return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.boundedElastic());
}
private Mono<OAuth2AuthorizedClient> authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
private Mono<OAuth2AuthorizedClient> reauthorizeClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
if (this.authorizedClientManager == null) {
return Mono.just(authorizedClient);
}
@ -472,7 +572,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
});
OAuth2AuthorizeRequest reauthorizeRequest = builder.build();
// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
// NOTE:
// 'authorizedClientManager.authorize()' needs to be executed
// on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
// since it performs a blocking I/O operation using RestTemplate internally
return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.boundedElastic());
}
@ -480,6 +582,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
return ClientRequest.from(request)
.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
}
@ -550,4 +653,183 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return new UnsupportedOperationException("Not Supported");
}
}
/**
* Forwards authentication and authorization failures to an
* {@link OAuth2AuthorizationFailureHandler}.
*
* @since 5.3
*/
private static class AuthorizationFailureForwarder implements ClientResponseHandler {
/**
* A map of HTTP status code to OAuth 2.0 error code for
* HTTP status codes that should be interpreted as
* authentication or authorization failures.
*/
private final Map<Integer, String> httpStatusToOAuth2ErrorCodeMap;
/**
* The {@link OAuth2AuthorizationFailureHandler} to notify
* when an authentication/authorization failure occurs.
*/
private final OAuth2AuthorizationFailureHandler authorizationFailureHandler;
private AuthorizationFailureForwarder(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
this.authorizationFailureHandler = authorizationFailureHandler;
Map<Integer, String> httpStatusToOAuth2Error = new HashMap<>();
httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN);
httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
this.httpStatusToOAuth2ErrorCodeMap = Collections.unmodifiableMap(httpStatusToOAuth2Error);
}
@Override
public Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> responseMono) {
return responseMono
.flatMap(response -> handleResponse(request, response)
.thenReturn(response))
.onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e)
.then(Mono.error(e)))
.onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e)
.then(Mono.error(e)));
}
private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) {
return Mono.justOrEmpty(resolveErrorIfPossible(response))
.flatMap(oauth2Error -> {
Map<String, Object> attrs = request.attributes();
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
if (authorizedClient == null) {
return Mono.empty();
}
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
oauth2Error, authorizedClient.getClientRegistration().getRegistrationId());
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs);
return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
});
}
private OAuth2Error resolveErrorIfPossible(ClientResponse response) {
// Try to resolve from 'WWW-Authenticate' header
if (!response.headers().header(HttpHeaders.WWW_AUTHENTICATE).isEmpty()) {
String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0);
Map<String, String> authParameters = parseAuthParameters(wwwAuthenticateHeader);
if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) {
return new OAuth2Error(
authParameters.get(OAuth2ParameterNames.ERROR),
authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION),
authParameters.get(OAuth2ParameterNames.ERROR_URI));
}
}
return resolveErrorIfPossible(response.rawStatusCode());
}
private OAuth2Error resolveErrorIfPossible(int statusCode) {
if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) {
return new OAuth2Error(
this.httpStatusToOAuth2ErrorCodeMap.get(statusCode),
null,
"https://tools.ietf.org/html/rfc6750#section-3.1");
}
return null;
}
private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
return Stream.of(wwwAuthenticateHeader)
.filter(header -> !StringUtils.isEmpty(header))
.filter(header -> header.toLowerCase().startsWith("bearer"))
.map(header -> header.substring("bearer".length()))
.map(header -> header.split(","))
.flatMap(Stream::of)
.map(parameter -> parameter.split("="))
.filter(parameter -> parameter.length > 1)
.collect(Collectors.toMap(
parameters -> parameters[0].trim(),
parameters -> parameters[1].trim().replace("\"", "")));
}
/**
* Handles the given http status code returned from a resource server
* by notifying the authorization failure handler if the http status
* code is in the {@link #httpStatusToOAuth2ErrorCodeMap}.
*
* @param request the request being processed
* @param exception The root cause exception for the failure
* @return a {@link Mono} that completes empty after the authorization failure handler completes
*/
private Mono<Void> handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) {
return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode()))
.flatMap(oauth2Error -> {
Map<String, Object> attrs = request.attributes();
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
if (authorizedClient == null) {
return Mono.empty();
}
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception);
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs);
return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
});
}
/**
* Handles the given {@link OAuth2AuthorizationException} that occurred downstream
* by notifying the authorization failure handler.
*
* @param request the request being processed
* @param authorizationException the authorization exception to include in the failure event
* @return a {@link Mono} that completes empty after the authorization failure handler completes
*/
private Mono<Void> handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException authorizationException) {
return Mono.justOrEmpty(request)
.flatMap(req -> {
Map<String, Object> attrs = req.attributes();
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
if (authorizedClient == null) {
return Mono.empty();
}
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs);
return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
});
}
/**
* Delegates the failed authorization to the {@link OAuth2AuthorizationFailureHandler}.
*
* @param exception the {@link OAuth2AuthorizationException} to include in the failure event
* @param principal the principal associated with the failed authorization attempt
* @param servletRequest the currently active {@code HttpServletRequest}
* @param servletResponse the currently active {@code HttpServletResponse}
* @return a {@link Mono} that completes empty after the {@link OAuth2AuthorizationFailureHandler} completes
*/
private Mono<Void> handleAuthorizationFailure(OAuth2AuthorizationException exception,
Authentication principal, HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure(
exception, principal, createAttributes(servletRequest, servletResponse));
return Mono.fromRunnable(runnable).subscribeOn(Schedulers.boundedElastic()).then();
}
private static Map<String, Object> createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
Map<String, Object> attributes = new HashMap<>();
attributes.put(HttpServletRequest.class.getName(), servletRequest);
attributes.put(HttpServletResponse.class.getName(), servletResponse);
return attributes;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,10 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -30,10 +34,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
/**
* Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}.
@ -45,6 +55,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
private OAuth2AuthorizedClientService authorizedClientService;
private OAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
@ -58,10 +70,14 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientService));
this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientService));
this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientService);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler);
this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
@ -97,6 +113,20 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationSuccessHandler cannot be null");
}
@Test
public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationFailureHandler cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
@ -134,8 +164,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull();
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
verifyNoInteractions(this.authorizationSuccessHandler);
verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
}
@SuppressWarnings("unchecked")
@ -160,6 +190,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(this.authorizedClient), eq(this.principal), any());
verify(this.authorizedClientService).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal));
}
@ -192,6 +224,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(reauthorizedClient), eq(this.principal), any());
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@ -213,8 +247,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientService, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal));
verifyNoInteractions(this.authorizationSuccessHandler);
verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
}
@SuppressWarnings("unchecked")
@ -240,6 +274,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(reauthorizedClient), eq(this.principal), any());
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@ -274,7 +310,52 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
assertThat(requestScopeAttribute).contains("read", "write");
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(reauthorizedClient), eq(this.principal), any());
verify(this.authorizedClientService).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal));
}
@Test
public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() {
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
this.clientRegistration.getRegistrationId());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.thenThrow(authorizationException);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
eq(authorizationException), eq(this.principal), any());
verify(this.authorizedClientService).removeAuthorizedClient(
eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()));
}
@Test
public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() {
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
new OAuth2Error("non-matching-error-code", null, null),
this.clientRegistration.getRegistrationId());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.thenThrow(authorizationException);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
eq(authorizationException), eq(this.principal), any());
verifyNoInteractions(this.authorizedClientService);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,13 +22,18 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -41,8 +46,16 @@ import java.util.Map;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
/**
* Tests for {@link DefaultOAuth2AuthorizedClientManager}.
@ -54,6 +67,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientProvider authorizedClientProvider;
private Function contextAttributesMapper;
private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
private DefaultOAuth2AuthorizedClientManager authorizedClientManager;
private ClientRegistration clientRegistration;
private Authentication principal;
@ -69,10 +84,14 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
this.contextAttributesMapper = mock(Function.class);
this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientRepository));
this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientRepository));
this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler);
this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler);
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
this.principal = new TestingAuthenticationToken("principal", "password");
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
@ -110,6 +129,20 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationSuccessHandler cannot be null");
}
@Test
public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationFailureHandler cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
@ -176,8 +209,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull();
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response));
verifyNoInteractions(this.authorizationSuccessHandler);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any(), any());
}
@SuppressWarnings("unchecked")
@ -206,6 +239,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(this.authorizedClient), eq(this.principal), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response));
}
@ -242,6 +277,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(reauthorizedClient), eq(this.principal), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response));
}
@ -308,6 +345,7 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verifyNoInteractions(this.authorizationSuccessHandler);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response));
}
@ -339,6 +377,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(reauthorizedClient);
verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
eq(reauthorizedClient), eq(this.principal), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response));
}
@ -372,4 +412,55 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
}
@Test
public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() {
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
this.clientRegistration.getRegistrationId());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.thenThrow(authorizationException);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attributes(attrs -> {
attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response);
})
.build();
assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
eq(authorizationException), eq(this.principal), any());
verify(this.authorizedClientRepository).removeAuthorizedClient(
eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request), eq(this.response));
}
@Test
public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() {
ClientAuthorizationException authorizationException = new ClientAuthorizationException(
new OAuth2Error("non-matching-error-code", null, null),
this.clientRegistration.getRegistrationId());
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.thenThrow(authorizationException);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attributes(attrs -> {
attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response);
})
.build();
assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
eq(authorizationException), eq(this.principal), any());
verifyNoInteractions(this.authorizedClientRepository);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,18 +15,6 @@
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -35,8 +23,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import reactor.util.context.Context;
import org.springframework.core.codec.ByteBufferEncoder;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.http.HttpHeaders;
@ -60,7 +46,9 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@ -78,6 +66,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -89,16 +80,37 @@ import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.entry;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
@ -128,6 +140,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordTokenResponseClient;
@Mock
private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
@Captor
private ArgumentCaptor<OAuth2AuthorizationException> authorizationExceptionCaptor;
@Captor
private ArgumentCaptor<Authentication> authenticationCaptor;
@Captor
private ArgumentCaptor<Map<String, Object>> attributesCaptor;
@Mock
private WebClient.RequestHeadersSpec<?> spec;
@Captor
private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
@ -167,7 +187,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager);
}
@After
@ -233,7 +253,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
verifyZeroInteractions(this.authorizedClientRepository);
verifyNoInteractions(this.authorizedClientRepository);
}
private Map<String, Object> getDefaultRequestAttributes() {
@ -647,6 +667,215 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
}
@Test
public void filterWhenForbiddenThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
}
private void assertHttpStatusInvokesFailureHandler(HttpStatus httpStatus, String expectedErrorCode) {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(httpServletRequest(servletRequest))
.attributes(httpServletResponse(servletResponse))
.build();
when(this.exchange.getResponse().rawStatusCode()).thenReturn(httpStatus.value());
when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class));
this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
this.function.filter(request, this.exchange).block();
verify(this.authorizationFailureHandler).onAuthorizationFailure(
this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(),
this.attributesCaptor.capture());
assertThat(this.authorizationExceptionCaptor.getValue())
.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode);
assertThat(e).hasNoCause();
assertThat(e).hasMessageContaining(expectedErrorCode);
});
assertThat(this.authenticationCaptor.getValue().getName())
.isEqualTo(authorizedClient.getPrincipalName());
assertThat(this.attributesCaptor.getValue())
.containsExactly(
entry(HttpServletRequest.class.getName(), servletRequest),
entry(HttpServletResponse.class.getName(), servletResponse));
}
@Test
public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(httpServletRequest(servletRequest))
.attributes(httpServletResponse(servletResponse))
.build();
String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " +
"error_description=\"The request requires higher privileges than provided by the access token.\", " +
"error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"";
ClientResponse.Headers headers = mock(ClientResponse.Headers.class);
when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE)))
.thenReturn(Collections.singletonList(wwwAuthenticateHeader));
when(this.exchange.getResponse().headers()).thenReturn(headers);
this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
this.function.filter(request, this.exchange).block();
verify(this.authorizationFailureHandler).onAuthorizationFailure(
this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(),
this.attributesCaptor.capture());
assertThat(this.authorizationExceptionCaptor.getValue())
.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token.");
assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1");
assertThat(e).hasNoCause();
assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
});
assertThat(this.authenticationCaptor.getValue().getName())
.isEqualTo(authorizedClient.getPrincipalName());
assertThat(this.attributesCaptor.getValue())
.containsExactly(
entry(HttpServletRequest.class.getName(), servletRequest),
entry(HttpServletResponse.class.getName(), servletResponse));
}
@Test
public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() {
assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
}
@Test
public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() {
assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
}
private void assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
HttpStatus httpStatus, String expectedErrorCode) {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(httpServletRequest(servletRequest))
.attributes(httpServletResponse(servletResponse))
.build();
WebClientResponseException exception = WebClientResponseException.create(
httpStatus.value(),
httpStatus.getReasonPhrase(),
HttpHeaders.EMPTY,
new byte[0],
StandardCharsets.UTF_8);
ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception);
this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block())
.isEqualTo(exception);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(),
this.attributesCaptor.capture());
assertThat(this.authorizationExceptionCaptor.getValue())
.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode);
assertThat(e).hasCause(exception);
assertThat(e).hasMessageContaining(expectedErrorCode);
});
assertThat(this.authenticationCaptor.getValue().getName())
.isEqualTo(authorizedClient.getPrincipalName());
assertThat(this.attributesCaptor.getValue())
.containsExactly(
entry(HttpServletRequest.class.getName(), servletRequest),
entry(HttpServletResponse.class.getName(), servletResponse));
}
@Test
public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(httpServletRequest(servletRequest))
.attributes(httpServletResponse(servletResponse))
.build();
OAuth2AuthorizationException authorizationException = new OAuth2AuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN));
ExchangeFunction throwingExchangeFunction = r -> Mono.error(authorizationException);
this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block())
.isEqualTo(authorizationException);
verify(this.authorizationFailureHandler).onAuthorizationFailure(
this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(),
this.attributesCaptor.capture());
assertThat(this.authorizationExceptionCaptor.getValue())
.isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> {
assertThat(e.getError().getErrorCode()).isEqualTo(authorizationException.getError().getErrorCode());
assertThat(e).hasNoCause();
assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INVALID_TOKEN);
});
assertThat(this.authenticationCaptor.getValue().getName())
.isEqualTo(authorizedClient.getPrincipalName());
assertThat(this.attributesCaptor.getValue())
.containsExactly(
entry(HttpServletRequest.class.getName(), servletRequest),
entry(HttpServletResponse.class.getName(), servletResponse));
}
@Test
public void filterWhenOtherHttpStatusThenDoesNotInvokeFailureHandler() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration, "principalName", this.accessToken);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(httpServletRequest(servletRequest))
.attributes(httpServletResponse(servletResponse))
.build();
when(this.exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value());
when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class));
this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
this.function.filter(request, this.exchange).block();
verifyNoInteractions(this.authorizationFailureHandler);
}
private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) {
Map<Object, Object> contextAttributes = new HashMap<>();
contextAttributes.put(HttpServletRequest.class, servletRequest);
@ -688,5 +917,4 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
request.body().insert(body, context).block();
return body.getBodyAsString().block();
}
}