Add WebClientReactiveClientCredentialsTokenResponseClient setWebClient

Added the ability to specify a custom WebClient in
WebClientReactiveClientCredentialsTokenResponseClient.
Also added testing to ensure the custom WebClient is not null and is
used.

Fixes: gh-6051
This commit is contained in:
jer051 2018-11-21 16:11:24 -06:00 committed by Rob Winch
parent 918a4cd323
commit fdc81822ec
2 changed files with 35 additions and 1 deletions

View File

@ -21,6 +21,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.BodyInserters;
@ -112,4 +113,9 @@ public class WebClientReactiveClientCredentialsTokenResponseClient implements Re
} }
return body; return body;
} }
public void setWebClient(WebClient webClient) {
Assert.notNull(webClient, "webClient cannot be null");
this.webClient = webClient;
}
} }

View File

@ -28,9 +28,11 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.reactive.function.client.WebClientResponseException;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
/** /**
* @author Rob Winch * @author Rob Winch
@ -55,6 +57,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@After @After
public void cleanup() throws Exception { public void cleanup() throws Exception {
validateMockitoUsage();
this.server.shutdown(); this.server.shutdown();
} }
@ -117,6 +120,31 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes()); assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes());
} }
@Test(expected=IllegalArgumentException.class)
public void setWebClientNullThenIllegalArgumentException(){
client.setWebClient(null);
}
@Test
public void setWebClientCustomThenCustomClientIsUsed() {
WebClient customClient = mock(WebClient.class);
when(customClient.post()).thenReturn(WebClient.builder().build().post());
this.client.setWebClient(customClient);
ClientRegistration registration = this.clientRegistration.build();
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
verify(customClient, atLeastOnce()).post();
}
@Test(expected = WebClientResponseException.class) @Test(expected = WebClientResponseException.class)
// gh-6089 // gh-6089
public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException { public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException {