diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle index ac5c5a3576..c297e2e5e0 100644 --- a/gradle/dependency-management.gradle +++ b/gradle/dependency-management.gradle @@ -70,6 +70,7 @@ dependencyManagement { dependency 'commons-lang:commons-lang:2.6' dependency 'commons-logging:commons-logging:1.2' dependency 'dom4j:dom4j:1.6.1' + dependency 'io.projectreactor.tools:blockhound:1.0.0.M4' dependency 'javax.activation:activation:1.1.1' dependency 'javax.annotation:jsr250-api:1.0' dependency 'javax.inject:javax.inject:1' diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 76a3006591..624125ee28 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -17,6 +17,7 @@ dependencies { testCompile 'com.fasterxml.jackson.core:jackson-databind' testCompile 'io.projectreactor.netty:reactor-netty' testCompile 'io.projectreactor:reactor-test' + testCompile 'io.projectreactor.tools:blockhound' provided 'javax.servlet:javax.servlet-api' } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 8045e52e4c..ebfb082e5c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -288,20 +288,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return Mono.just(request) - .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .switchIfEmpty(mergeRequestAttributesFromContext(request)) + return mergeRequestAttributesIfNecessary(request) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) + .switchIfEmpty(Mono.defer(() -> + mergeRequestAttributesIfNecessary(request) + .filter(req -> resolveClientRegistrationId(req.attributes()) != null) + .flatMap(this::authorizeClient) + )) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) - .switchIfEmpty(next.exchange(request)); + .switchIfEmpty(Mono.defer(() -> next.exchange(request))); + } + + private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { + if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() || + !request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() || + !request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { + return mergeRequestAttributesFromContext(request); + } else { + return Mono.just(request); + } } private Mono mergeRequestAttributesFromContext(ClientRequest request) { - return Mono.just(ClientRequest.from(request)) - .flatMap(builder -> Mono.subscriberContext() - .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))) + ClientRequest.Builder builder = ClientRequest.from(request); + return Mono.subscriberContext() + .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) .map(ClientRequest.Builder::build); } @@ -348,40 +361,47 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction return; } + String clientRegistrationId = resolveClientRegistrationId(attrs); Authentication authentication = getAuthentication(attrs); + HttpServletRequest request = getRequest(attrs); + if (clientRegistrationId != null && authentication != null && request != null) { + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, authentication, request); + if (authorizedClient != null) { + oauth2AuthorizedClient(authorizedClient).accept(attrs); + } + } + } + + private String resolveClientRegistrationId(Map attrs) { String clientRegistrationId = getClientRegistrationId(attrs); if (clientRegistrationId == null) { clientRegistrationId = this.defaultClientRegistrationId; } + Authentication authentication = getAuthentication(attrs); if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) { clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); } - if (clientRegistrationId != null) { - HttpServletRequest request = getRequest(attrs); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository - .loadAuthorizedClient(clientRegistrationId, authentication, - request); - if (authorizedClient == null) { - authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); - } - oauth2AuthorizedClient(authorizedClient).accept(attrs); - } + return clientRegistrationId; } - private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map attrs) { + private Mono authorizeClient(ClientRequest request) { + Map attrs = request.attributes(); + String clientRegistrationId = resolveClientRegistrationId(attrs); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); if (clientRegistration == null) { throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - return getAuthorizedClient(clientRegistration, attrs); + // NOTE: 'getAuthorizedClient()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic()) + // since it performs a blocking I/O operation using RestTemplate internally + return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, attrs)).subscribeOn(Schedulers.elastic()); } throw new ClientAuthorizationRequiredException(clientRegistrationId); } - private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map attrs) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java new file mode 100644 index 0000000000..fb8cf3e52d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -0,0 +1,268 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.blockhound.BlockHound; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; + +/** + * @author Joe Grandja + */ +public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private ServletOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter; + private MockWebServer server; + private String serverUrl; + private WebClient webClient; + private Authentication authentication; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + + @BeforeClass + public static void setUpBlockingChecks() { + // IMPORTANT: + // Before enabling BlockHound, we need to white-list `java.lang.Class.getPackage()`. + // When the JVM loads `java.lang.Package.getSystemPackage()`, it attempts to + // `java.lang.Package.loadManifest()` which is blocking I/O and triggers BlockHound to error. + // NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine w/o this white-list. + BlockHound.builder() + .allowBlockingCallsInside(Class.class.getName(), "getPackage") + .install(); + } + + @Before + public void setUp() throws Exception { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + final OAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository)); + this.authorizedClientRepository = spy(new OAuth2AuthorizedClientRepository() { + @Override + public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request) { + return delegate.loadAuthorizedClient(clientRegistrationId, principal, request); + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + delegate.saveAuthorizedClient(authorizedClient, principal, request, response); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + }); + this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientFilter.afterPropertiesSet(); + this.server = new MockWebServer(); + this.server.start(); + this.serverUrl = this.server.url("/").toString(); + this.webClient = WebClient.builder() + .apply(this.authorizedClientFilter.oauth2Configuration()) + .build(); + this.authentication = new TestingAuthenticationToken("principal", "password"); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + } + + @After + public void cleanup() throws Exception { + this.authorizedClientFilter.destroy(); + this.server.shutdown(); + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); + } + + @Test + public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"refreshed-access-token\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write"))); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, this.authentication.getName(), accessToken, refreshToken); + doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient( + eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request)); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); + assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); + } + + @Test + public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + // Client 1 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials() + .registrationId("client-1").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(clientRegistration1); + + // Client 2 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials() + .registrationId("client-2").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(clientRegistration2); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration1.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .flatMap(response -> this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) + .retrieve() + .bodyToMono(String.class)) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(4); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1); + assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 6b71657f62..beba0d96b4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -75,7 +75,6 @@ import java.util.Optional; import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; @@ -207,7 +206,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { } @Test - public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { + public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); @@ -241,6 +243,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); @@ -259,7 +264,11 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { } @Test - public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { + public void defaultRequestOAuth2AuthorizedClientWhenClientRegistrationIdThenOAuth2AuthorizedClient() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -273,63 +282,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any()); } - @Test - public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { - this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); - - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - - assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); - assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - } - - @Test - public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() { - this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId()); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); - - Map attrs = getDefaultRequestAttributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - - assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); - assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - } - - @Test - public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() { - this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); - - assertThatCode(() -> getDefaultRequestAttributes()) - .isInstanceOf(IllegalArgumentException.class); - } - private Map getDefaultRequestAttributes() { this.function.defaultRequest().accept(this.spec); verify(this.spec).attributes(this.attrs.capture());