Add setBodyExtractor

Closes gh-10260
This commit is contained in:
bishoy basily 2021-09-14 22:27:57 +02:00 committed by Josh Cummings
parent c3ba2332da
commit 860690491a
5 changed files with 130 additions and 1 deletions

View File

@ -27,6 +27,7 @@ import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -35,6 +36,7 @@ import org.springframework.security.oauth2.core.web.reactive.function.OAuth2Body
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
@ -68,6 +70,9 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
}
@ -204,7 +209,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
* @return the token response from the response body.
*/
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse())
return response.body(this.bodyExtractor)
.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
}
@ -291,4 +296,17 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
};
}
/**
* Sets the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}
* @param bodyExtractor the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}
* @since 5.6
*/
public final void setBodyExtractor(
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor) {
Assert.notNull(bodyExtractor, "bodyExtractor cannot be null");
this.bodyExtractor = bodyExtractor;
}
}

View File

@ -27,11 +27,13 @@ import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -41,11 +43,14 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
@ -409,4 +414,25 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
customClient.setBodyExtractor(extractor);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest())
.block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
}
}

View File

@ -27,21 +27,26 @@ import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
@ -280,4 +285,26 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
enqueueJson("{}");
WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
customClient.setBodyExtractor(extractor);
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(request).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
}
}

View File

@ -25,21 +25,26 @@ import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -286,4 +291,29 @@ public class WebClientReactivePasswordTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
customClient.setBodyExtractor(extractor);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
}
}

View File

@ -25,11 +25,13 @@ import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -39,10 +41,13 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -289,4 +294,27 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
String accessTokenSuccessResponse = "{}";
WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient();
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));
customClient.setBodyExtractor(extractor);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();
}
}