Fix blocking in ServletOAuth2AuthorizedClientExchangeFilterFunction

Fixes gh-6589
This commit is contained in:
Joe Grandja 2019-06-21 10:50:38 -04:00
parent c05b0765c1
commit 4ca9e15595
5 changed files with 317 additions and 187 deletions

View File

@ -70,6 +70,7 @@ dependencyManagement {
dependency 'commons-lang:commons-lang:2.6'
dependency 'commons-logging:commons-logging:1.2'
dependency 'dom4j:dom4j:1.6.1'
dependency 'io.projectreactor.tools:blockhound:1.0.0.M4'
dependency 'javax.activation:activation:1.1.1'
dependency 'javax.annotation:jsr250-api:1.0'
dependency 'javax.inject:javax.inject:1'

View File

@ -17,6 +17,7 @@ dependencies {
testCompile 'com.fasterxml.jackson.core:jackson-databind'
testCompile 'io.projectreactor.netty:reactor-netty'
testCompile 'io.projectreactor:reactor-test'
testCompile 'io.projectreactor.tools:blockhound'
provided 'javax.servlet:javax.servlet-api'
}

View File

@ -50,6 +50,7 @@ import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest;
@ -258,7 +259,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
spec.attributes(attrs -> {
populateDefaultRequestResponse(attrs);
populateDefaultAuthentication(attrs);
populateDefaultOAuth2AuthorizedClient(attrs);
});
};
}
@ -349,20 +349,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return Mono.just(request)
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.switchIfEmpty(mergeRequestAttributesFromContext(request))
return mergeRequestAttributesIfNecessary(request)
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req))
.switchIfEmpty(Mono.defer(() ->
mergeRequestAttributesIfNecessary(request)
.filter(req -> resolveClientRegistrationId(req) != null)
.flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req))
))
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
}
private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() ||
!request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() ||
!request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) {
return mergeRequestAttributesFromContext(request);
} else {
return Mono.just(request);
}
}
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
return Mono.just(ClientRequest.from(request))
.flatMap(builder -> Mono.subscriberContext()
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
ClientRequest.Builder builder = ClientRequest.from(request);
return Mono.subscriberContext()
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))
.map(ClientRequest.Builder::build);
}
@ -376,7 +389,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
}
populateDefaultOAuth2AuthorizedClient(attrs);
}
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
@ -403,32 +415,38 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}
private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
if (this.authorizedClientManager == null ||
attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
return;
}
Authentication authentication = getAuthentication(attrs);
private String resolveClientRegistrationId(ClientRequest request) {
Map<String, Object> attrs = request.attributes();
String clientRegistrationId = getClientRegistrationId(attrs);
if (clientRegistrationId == null) {
clientRegistrationId = this.defaultClientRegistrationId;
}
Authentication authentication = getAuthentication(attrs);
if (clientRegistrationId == null
&& this.defaultOAuth2AuthorizedClient
&& authentication instanceof OAuth2AuthenticationToken) {
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
}
if (clientRegistrationId != null) {
HttpServletRequest request = getRequest(attrs);
if (authentication == null) {
authentication = ANONYMOUS_AUTHENTICATION;
}
OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest(
clientRegistrationId, authentication, request, getResponse(attrs));
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
oauth2AuthorizedClient(authorizedClient).accept(attrs);
return clientRegistrationId;
}
private Mono<OAuth2AuthorizedClient> authorizeClient(String clientRegistrationId, ClientRequest request) {
if (this.authorizedClientManager == null) {
return Mono.empty();
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
if (authentication == null) {
authentication = ANONYMOUS_AUTHENTICATION;
}
HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs);
OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest(
clientRegistrationId, authentication, servletRequest, servletResponse);
// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic())
// since it performs a blocking I/O operation using RestTemplate internally
return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.elastic());
}
private Mono<OAuth2AuthorizedClient> authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
@ -444,7 +462,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
HttpServletResponse servletResponse = getResponse(attrs);
OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest(
authorizedClient, authentication, servletRequest, servletResponse);
return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest));
// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic())
// since it performs a blocking I/O operation using RestTemplate internally
return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.elastic());
}
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {

View File

@ -0,0 +1,268 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.blockhound.BlockHound;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashSet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
/**
* @author Joe Grandja
*/
public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private ServletOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter;
private MockWebServer server;
private String serverUrl;
private WebClient webClient;
private Authentication authentication;
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@BeforeClass
public static void setUpBlockingChecks() {
// IMPORTANT:
// Before enabling BlockHound, we need to white-list `java.lang.Class.getPackage()`.
// When the JVM loads `java.lang.Package.getSystemPackage()`, it attempts to
// `java.lang.Package.loadManifest()` which is blocking I/O and triggers BlockHound to error.
// NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine w/o this white-list.
BlockHound.builder()
.allowBlockingCallsInside(Class.class.getName(), "getPackage")
.install();
}
@Before
public void setUp() throws Exception {
this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
final OAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository));
this.authorizedClientRepository = spy(new OAuth2AuthorizedClientRepository() {
@Override
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request) {
return delegate.loadAuthorizedClient(clientRegistrationId, principal, request);
}
@Override
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, HttpServletRequest request, HttpServletResponse response) {
delegate.saveAuthorizedClient(authorizedClient, principal, request, response);
}
@Override
public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, HttpServletResponse response) {
delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response);
}
});
this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientFilter.afterPropertiesSet();
this.server = new MockWebServer();
this.server.start();
this.serverUrl = this.server.url("/").toString();
this.webClient = WebClient.builder()
.apply(this.authorizedClientFilter.oauth2Configuration())
.build();
this.authentication = new TestingAuthenticationToken("principal", "password");
SecurityContextHolder.getContext().setAuthentication(this.authentication);
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response));
}
@After
public void cleanup() throws Exception {
this.authorizedClientFilter.destroy();
this.server.shutdown();
SecurityContextHolder.clearContext();
RequestContextHolder.resetRequestAttributes();
}
@Test
public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"read write\"\n" +
"}\n";
String clientResponse = "{\n" +
" \"attribute1\": \"value1\",\n" +
" \"attribute2\": \"value2\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration);
this.webClient
.get()
.uri(this.serverUrl)
.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
.block();
assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
verify(this.authorizedClientRepository).saveAuthorizedClient(
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
}
@Test
public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
String accessTokenResponse = "{\n" +
" \"access_token\": \"refreshed-access-token\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
String clientResponse = "{\n" +
" \"attribute1\": \"value1\",\n" +
" \"attribute2\": \"value2\"\n" +
"}\n";
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration);
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofHours(1));
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write")));
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, this.authentication.getName(), accessToken, refreshToken);
doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient(
eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request));
this.webClient
.get()
.uri(this.serverUrl)
.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
.block();
assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
verify(this.authorizedClientRepository).saveAuthorizedClient(
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue();
assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration);
assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token");
}
@Test
public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"read write\"\n" +
"}\n";
String clientResponse = "{\n" +
" \"attribute1\": \"value1\",\n" +
" \"attribute2\": \"value2\"\n" +
"}\n";
// Client 1
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials()
.registrationId("client-1").tokenUri(this.serverUrl).build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(clientRegistration1);
// Client 2
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials()
.registrationId("client-2").tokenUri(this.serverUrl).build();
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(clientRegistration2);
this.webClient
.get()
.uri(this.serverUrl)
.attributes(clientRegistrationId(clientRegistration1.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
.flatMap(response -> this.webClient
.get()
.uri(this.serverUrl)
.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
.retrieve()
.bodyToMono(String.class))
.block();
assertThat(this.server.getRequestCount()).isEqualTo(4);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient(
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1);
assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
}

View File

@ -84,7 +84,6 @@ import java.util.Optional;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
@ -212,166 +211,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
verifyZeroInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
oauth2AuthorizedClient(authorizedClient).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
verifyZeroInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
verifyZeroInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
verifyZeroInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2AuthorizedClient() {
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
authentication(token).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
authentication(token).accept(this.result);
httpServletRequest(new MockHttpServletRequest()).accept(this.result);
httpServletResponse(new MockHttpServletResponse()).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any());
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
authentication(token).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
authentication(token).accept(this.result);
clientRegistrationId("explicit").accept(this.result);
httpServletRequest(new MockHttpServletRequest()).accept(this.result);
httpServletResponse(new MockHttpServletResponse()).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
verify(this.authorizedClientRepository).loadAuthorizedClient(eq("explicit"), any(), any());
}
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
clientRegistrationId("id").accept(this.result);
httpServletRequest(new MockHttpServletRequest()).accept(this.result);
httpServletResponse(new MockHttpServletResponse()).accept(this.result);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any());
}
@Test
public void defaultRequestWhenClientCredentialsThenAuthorizedClient() {
this.registration = TestClientRegistrations.clientCredentials().build();
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse().build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response));
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Map<String, Object> attrs = getDefaultRequestAttributes();
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo("test");
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}
@Test
public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() {
this.registration = TestClientRegistrations.clientCredentials().build();
this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId());
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse().build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response));
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Map<String, Object> attrs = getDefaultRequestAttributes();
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo("test");
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}
@Test
public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() {
this.registration = TestClientRegistrations.clientCredentials().build();
clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
assertThatCode(() -> getDefaultRequestAttributes())
.isInstanceOf(IllegalArgumentException.class);
}
private Map<String, Object> getDefaultRequestAttributes() {
this.function.defaultRequest().accept(this.spec);
verify(this.spec).attributes(this.attrs.capture());