Move parametersCustomizer

The parametersCustomizer was introduced in 6.4.0-M4 with
DefaultOAuth2TokenRequestParametersConverter. However, it cannot be
applied to all parameters and so does not fully solve gh-11298.

This commit moves the customizer to the abstract class so it can be
applied to all parameters.

Closes gh-15939
This commit is contained in:
Steve Riesenberg 2024-10-17 11:51:32 -05:00
parent af2b84246b
commit dab6950231
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
14 changed files with 153 additions and 47 deletions

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.function.Consumer;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.converter.FormHttpMessageConverter;
@ -76,6 +78,9 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};
AbstractRestClientOAuth2AccessTokenResponseClient() {
}
@ -127,6 +132,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
this.parametersCustomizer.accept(parameters);
return this.restClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
@ -243,4 +249,14 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
this.requestEntityConverter = this::populateRequest;
}
/**
* Sets the {@link Consumer} used for customizing the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
*/
public void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.function.Consumer;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
@ -68,6 +70,9 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};
private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
@ -108,6 +113,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
this.parametersCustomizer.accept(parameters);
return this.webClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
@ -228,6 +234,16 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
this.requestEntityConverter = this::populateRequest;
}
/**
* Sets the {@link Consumer} used for customizing the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
*/
public void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}
/**
* Sets the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}

View File

@ -16,13 +16,10 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.function.Consumer;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -64,19 +61,6 @@ public final class DefaultOAuth2TokenRequestParametersConverter<T extends Abstra
private final Converter<T, MultiValueMap<String, String>> defaultParametersConverter = createDefaultParametersConverter();
private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};
/**
* Sets the {@link Consumer} used for customizing the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
*/
public void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}
@Override
public MultiValueMap<String, String> convert(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
@ -95,7 +79,6 @@ public final class DefaultOAuth2TokenRequestParametersConverter<T extends Abstra
parameters.addAll(defaultParameters);
}
this.parametersCustomizer.accept(parameters);
return parameters;
}

View File

@ -154,6 +154,15 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
// @formatter:on
}
@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}
@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -439,12 +448,7 @@ public class RestClientAuthorizationCodeTokenResponseClientTests {
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2AuthorizationCodeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

View File

@ -138,6 +138,15 @@ public class RestClientClientCredentialsTokenResponseClientTests {
// @formatter:on
}
@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}
@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -438,12 +447,7 @@ public class RestClientClientCredentialsTokenResponseClientTests {
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2ClientCredentialsGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

View File

@ -140,6 +140,15 @@ public class RestClientJwtBearerTokenResponseClientTests {
// @formatter:on
}
@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}
@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -414,12 +423,7 @@ public class RestClientJwtBearerTokenResponseClientTests {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<JwtBearerGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

View File

@ -147,6 +147,15 @@ public class RestClientRefreshTokenTokenResponseClientTests {
// @formatter:on
}
@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}
@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -461,12 +470,7 @@ public class RestClientRefreshTokenTokenResponseClientTests {
OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<OAuth2RefreshTokenGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

View File

@ -148,6 +148,15 @@ public class RestClientTokenExchangeTokenResponseClientTests {
// @formatter:on
}
@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}
@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -545,12 +554,7 @@ public class RestClientTokenExchangeTokenResponseClientTests {
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken,
this.actorToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
// @formatter:off
DefaultOAuth2TokenRequestParametersConverter<TokenExchangeGrantRequest> parametersConverter =
new DefaultOAuth2TokenRequestParametersConverter<>();
// @formatter:on
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

View File

@ -22,6 +22,7 @@ import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
@ -406,6 +407,16 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -20,6 +20,7 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
@ -365,6 +366,17 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.client.setParametersCustomizer(parametersCustomizer);
this.client.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -289,6 +290,17 @@ public class WebClientReactiveJwtBearerTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.client.setParametersCustomizer(parametersCustomizer);
this.client.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock();

View File

@ -20,6 +20,7 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
@ -408,6 +409,17 @@ public class WebClientReactivePasswordTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
this.username, this.password);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -20,6 +20,7 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
@ -378,6 +379,17 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

View File

@ -21,6 +21,7 @@ import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Consumer;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -529,6 +530,17 @@ public class WebClientReactiveTokenExchangeTokenResponseClientTests {
// @formatter:on
}
@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
this.server.enqueue(MockResponses.json("access-token-response.json"));
TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(),
this.subjectToken, this.actorToken);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock();
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}
@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
this.server.enqueue(MockResponses.json("access-token-response.json"));