Add defaultOAuth2AuthorizedClient flag

Fixes: gh-5619
This commit is contained in:
Rob Winch 2018-07-31 14:44:40 -05:00
parent cecbc2175b
commit 1a65abd781
2 changed files with 33 additions and 3 deletions

View File

@ -109,12 +109,25 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private boolean defaultOAuth2AuthorizedClient;
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientRepository = authorizedClientRepository;
}
/**
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
}
/**
* Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction}
* @return the {@link Consumer} to configure the builder
@ -251,13 +264,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
}
private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
if (this.authorizedClientRepository == null
|| attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
return;
}
Authentication authentication = getAuthentication(attrs);
String clientRegistrationId = getClientRegistrationId(attrs);
if (clientRegistrationId == null && authentication instanceof OAuth2AuthenticationToken) {
if (clientRegistrationId == null
&& this.defaultOAuth2AuthorizedClient
&& authentication instanceof OAuth2AuthenticationToken) {
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
}
if (clientRegistrationId != null) {

View File

@ -207,8 +207,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@ -223,6 +224,19 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any());
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
authentication(token).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);