Fix OAuth2AuthorizationRequest additionalParameters/attributes Consumer

Fixes gh-8177
This commit is contained in:
Joe Grandja 2020-03-23 16:34:50 -04:00
parent 2c103f34e3
commit 46baf38f59
6 changed files with 23 additions and 26 deletions

View File

@ -27,7 +27,6 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -71,8 +70,8 @@ public class OAuth2AuthorizationRequestMixinTests {
this.authorizationRequestBuilder this.authorizationRequestBuilder
.scopes(null) .scopes(null)
.state(null) .state(null)
.additionalParameters(Collections.emptyMap()) .additionalParameters(Map::clear)
.attributes(Collections.emptyMap()) .attributes(Map::clear)
.build(); .build();
String expectedJson = asJson(authorizationRequest); String expectedJson = asJson(authorizationRequest);
String json = this.mapper.writeValueAsString(authorizationRequest); String json = this.mapper.writeValueAsString(authorizationRequest);
@ -119,8 +118,8 @@ public class OAuth2AuthorizationRequestMixinTests {
this.authorizationRequestBuilder this.authorizationRequestBuilder
.scopes(null) .scopes(null)
.state(null) .state(null)
.additionalParameters(Collections.emptyMap()) .additionalParameters(Map::clear)
.attributes(Collections.emptyMap()) .attributes(Map::clear)
.build(); .build();
String json = asJson(expectedAuthorizationRequest); String json = asJson(expectedAuthorizationRequest);
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);

View File

@ -437,6 +437,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + .matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" + "response_type=code&client_id=client-id&" +

View File

@ -29,6 +29,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
@ -162,6 +163,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID);
assertThat(authorizationRequest.getAuthorizationRequestUri()) assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + .matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" + "response_type=code&client_id=client-id&" +

View File

@ -35,6 +35,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -96,7 +97,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
@Test @Test
public void applyWhenAttributesMissingThenOAuth2AuthorizationException() { public void applyWhenAttributesMissingThenOAuth2AuthorizationException() {
this.authorizationRequest.attributes(Collections.emptyMap()); this.authorizationRequest.attributes(Map::clear);
when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build()));
assertThatThrownBy(() -> applyConverter()) assertThatThrownBy(() -> applyConverter())

View File

@ -229,9 +229,9 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private String redirectUri; private String redirectUri;
private Set<String> scopes; private Set<String> scopes;
private String state; private String state;
private Consumer<Map<String, Object>> additionalParametersConsumer = params -> {}; private Map<String, Object> additionalParameters = new LinkedHashMap<>();
private Consumer<Map<String, Object>> parametersConsumer = params -> {}; private Consumer<Map<String, Object>> parametersConsumer = params -> {};
private Consumer<Map<String, Object>> attributesConsumer = attrs -> {}; private Map<String, Object> attributes = new LinkedHashMap<>();
private String authorizationRequestUri; private String authorizationRequestUri;
private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build(); private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build();
private final DefaultUriBuilderFactory uriBuilderFactory; private final DefaultUriBuilderFactory uriBuilderFactory;
@ -325,8 +325,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public Builder additionalParameters(Map<String, Object> additionalParameters) { public Builder additionalParameters(Map<String, Object> additionalParameters) {
if (additionalParameters != null) { if (!CollectionUtils.isEmpty(additionalParameters)) {
return additionalParameters(params -> params.putAll(additionalParameters)); this.additionalParameters.putAll(additionalParameters);
} }
return this; return this;
} }
@ -340,7 +340,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
*/ */
public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) { public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
if (additionalParametersConsumer != null) { if (additionalParametersConsumer != null) {
this.additionalParametersConsumer = additionalParametersConsumer; additionalParametersConsumer.accept(this.additionalParameters);
} }
return this; return this;
} }
@ -367,8 +367,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public Builder attributes(Map<String, Object> attributes) { public Builder attributes(Map<String, Object> attributes) {
if (attributes != null) { if (!CollectionUtils.isEmpty(attributes)) {
return attributes(attrs -> attrs.putAll(attributes)); this.attributes.putAll(attributes);
} }
return this; return this;
} }
@ -382,7 +382,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
*/ */
public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) { public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
if (attributesConsumer != null) { if (attributesConsumer != null) {
this.attributesConsumer = attributesConsumer; attributesConsumer.accept(this.attributes);
} }
return this; return this;
} }
@ -439,12 +439,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
authorizationRequest.scopes = Collections.unmodifiableSet( authorizationRequest.scopes = Collections.unmodifiableSet(
CollectionUtils.isEmpty(this.scopes) ? CollectionUtils.isEmpty(this.scopes) ?
Collections.emptySet() : new LinkedHashSet<>(this.scopes)); Collections.emptySet() : new LinkedHashSet<>(this.scopes));
Map<String, Object> additionalParameters = new LinkedHashMap<>(); authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters);
this.additionalParametersConsumer.accept(additionalParameters); authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes);
authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters);
Map<String, Object> attributes = new LinkedHashMap<>();
this.attributesConsumer.accept(attributes);
authorizationRequest.attributes = Collections.unmodifiableMap(attributes);
authorizationRequest.authorizationRequestUri = authorizationRequest.authorizationRequestUri =
StringUtils.hasText(this.authorizationRequestUri) ? StringUtils.hasText(this.authorizationRequestUri) ?
this.authorizationRequestUri : this.buildAuthorizationRequestUri(); this.authorizationRequestUri : this.buildAuthorizationRequestUri();
@ -457,7 +453,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
this.parametersConsumer.accept(parameters); this.parametersConsumer.accept(parameters);
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>(); MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
parameters.forEach((k, v) -> queryParams.set( parameters.forEach((k, v) -> queryParams.set(
encodeQueryParam(k), encodeQueryParam(v.toString()))); // Encoded encodeQueryParam(k), encodeQueryParam(String.valueOf(v)))); // Encoded
UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri) UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri)
.queryParams(queryParams); .queryParams(queryParams);
return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
@ -477,9 +473,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
if (this.redirectUri != null) { if (this.redirectUri != null) {
parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri); parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
} }
Map<String, Object> additionalParameters = new LinkedHashMap<>(); parameters.putAll(this.additionalParameters);
this.additionalParametersConsumer.accept(additionalParameters);
additionalParameters.forEach((k, v) -> parameters.put(k, v.toString()));
return parameters; return parameters;
} }

View File

@ -121,7 +121,7 @@ public class OAuth2AuthorizationRequestTests {
} }
@Test @Test
public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() { public void buildWhenAdditionalParametersEmptyThenDoesNotThrowAnyException() {
assertThatCode(() -> assertThatCode(() ->
OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI) .authorizationUri(AUTHORIZATION_URI)
@ -129,7 +129,7 @@ public class OAuth2AuthorizationRequestTests {
.redirectUri(REDIRECT_URI) .redirectUri(REDIRECT_URI)
.scopes(SCOPES) .scopes(SCOPES)
.state(STATE) .state(STATE)
.additionalParameters((Map) null) .additionalParameters(Map::clear)
.build()) .build())
.doesNotThrowAnyException(); .doesNotThrowAnyException();
} }