Issue gh-7350
This commit is contained in:
Joe Grandja 2019-09-04 14:27:01 -04:00
parent 0ac8618eac
commit 409285fb3d
2 changed files with 35 additions and 2 deletions

View File

@ -27,6 +27,7 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@ -49,11 +50,14 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.ServletWebRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
@ -285,6 +289,19 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(passwordAuthorizedClientProvider);
// Set custom contextAttributesMapper
authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
Map<String, Object> contextAttributes = new HashMap<>();
String username = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.USERNAME);
String password = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.PASSWORD);
if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
}
return contextAttributes;
});
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse

View File

@ -46,6 +46,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@ -68,6 +69,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
@ -120,6 +122,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Captor
private ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor;
private DefaultOAuth2AuthorizedClientManager authorizedClientManager;
/**
* Used for get the attributes from defaultRequest.
*/
@ -148,9 +152,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
.clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.password(configurer -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient))
.build();
DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
}
@ -459,6 +463,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
ClientRegistration registration = TestClientRegistrations.password().build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(registration);
// Set custom contextAttributesMapper
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
Map<String, Object> contextAttributes = new HashMap<>();
String username = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.USERNAME);
String password = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.PASSWORD);
if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
}
return contextAttributes;
});
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setParameter(OAuth2ParameterNames.USERNAME, "username");
servletRequest.setParameter(OAuth2ParameterNames.PASSWORD, "password");