From a1bcd4ed0045d8dab612e9900919eb3b84de2643 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 23 Mar 2020 16:34:50 -0400 Subject: [PATCH] Fix OAuth2AuthorizationRequest additionalParameters/attributes Consumer Fixes gh-8177 --- .../OAuth2AuthorizationRequestMixinTests.java | 9 +++--- ...uth2AuthorizationRequestResolverTests.java | 1 + ...uth2AuthorizationRequestResolverTests.java | 2 ++ ...nCodeAuthenticationTokenConverterTest.java | 3 +- .../endpoint/OAuth2AuthorizationRequest.java | 30 ++++++++----------- .../OAuth2AuthorizationRequestTests.java | 4 +-- 6 files changed, 23 insertions(+), 26 deletions(-) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java index 21ca65fb87..2630efeabc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java @@ -27,7 +27,6 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.stream.Collectors; @@ -71,8 +70,8 @@ public class OAuth2AuthorizationRequestMixinTests { this.authorizationRequestBuilder .scopes(null) .state(null) - .additionalParameters(Collections.emptyMap()) - .attributes(Collections.emptyMap()) + .additionalParameters(Map::clear) + .attributes(Map::clear) .build(); String expectedJson = asJson(authorizationRequest); String json = this.mapper.writeValueAsString(authorizationRequest); @@ -119,8 +118,8 @@ public class OAuth2AuthorizationRequestMixinTests { this.authorizationRequestBuilder .scopes(null) .state(null) - .additionalParameters(Collections.emptyMap()) - .attributes(Collections.emptyMap()) + .additionalParameters(Map::clear) + .attributes(Map::clear) .build(); String json = asJson(expectedAuthorizationRequest); OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 45cf3897ef..2f1f315101 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -437,6 +437,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java index 77479bcf8e..958799b014 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java @@ -29,6 +29,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; @@ -162,6 +163,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java index 83aa315cd5..95b19f7014 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java @@ -35,6 +35,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import reactor.core.publisher.Mono; import java.util.Collections; +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -96,7 +97,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { @Test public void applyWhenAttributesMissingThenOAuth2AuthorizationException() { - this.authorizationRequest.attributes(Collections.emptyMap()); + this.authorizationRequest.attributes(Map::clear); when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); assertThatThrownBy(() -> applyConverter()) diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java index 28937c61c4..9323ce2a53 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java @@ -229,9 +229,9 @@ public final class OAuth2AuthorizationRequest implements Serializable { private String redirectUri; private Set scopes; private String state; - private Consumer> additionalParametersConsumer = params -> {}; + private Map additionalParameters = new LinkedHashMap<>(); private Consumer> parametersConsumer = params -> {}; - private Consumer> attributesConsumer = attrs -> {}; + private Map attributes = new LinkedHashMap<>(); private String authorizationRequestUri; private Function authorizationRequestUriFunction = builder -> builder.build(); private final DefaultUriBuilderFactory uriBuilderFactory; @@ -325,8 +325,8 @@ public final class OAuth2AuthorizationRequest implements Serializable { * @return the {@link Builder} */ public Builder additionalParameters(Map additionalParameters) { - if (additionalParameters != null) { - return additionalParameters(params -> params.putAll(additionalParameters)); + if (!CollectionUtils.isEmpty(additionalParameters)) { + this.additionalParameters.putAll(additionalParameters); } return this; } @@ -340,7 +340,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { */ public Builder additionalParameters(Consumer> additionalParametersConsumer) { if (additionalParametersConsumer != null) { - this.additionalParametersConsumer = additionalParametersConsumer; + additionalParametersConsumer.accept(this.additionalParameters); } return this; } @@ -367,8 +367,8 @@ public final class OAuth2AuthorizationRequest implements Serializable { * @return the {@link Builder} */ public Builder attributes(Map attributes) { - if (attributes != null) { - return attributes(attrs -> attrs.putAll(attributes)); + if (!CollectionUtils.isEmpty(attributes)) { + this.attributes.putAll(attributes); } return this; } @@ -382,7 +382,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { */ public Builder attributes(Consumer> attributesConsumer) { if (attributesConsumer != null) { - this.attributesConsumer = attributesConsumer; + attributesConsumer.accept(this.attributes); } return this; } @@ -439,12 +439,8 @@ public final class OAuth2AuthorizationRequest implements Serializable { authorizationRequest.scopes = Collections.unmodifiableSet( CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes)); - Map additionalParameters = new LinkedHashMap<>(); - this.additionalParametersConsumer.accept(additionalParameters); - authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters); - Map attributes = new LinkedHashMap<>(); - this.attributesConsumer.accept(attributes); - authorizationRequest.attributes = Collections.unmodifiableMap(attributes); + authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters); + authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes); authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri) ? this.authorizationRequestUri : this.buildAuthorizationRequestUri(); @@ -457,7 +453,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { this.parametersConsumer.accept(parameters); MultiValueMap queryParams = new LinkedMultiValueMap<>(); parameters.forEach((k, v) -> queryParams.set( - encodeQueryParam(k), encodeQueryParam(v.toString()))); // Encoded + encodeQueryParam(k), encodeQueryParam(String.valueOf(v)))); // Encoded UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri) .queryParams(queryParams); return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); @@ -477,9 +473,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { if (this.redirectUri != null) { parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri); } - Map additionalParameters = new LinkedHashMap<>(); - this.additionalParametersConsumer.accept(additionalParameters); - additionalParameters.forEach((k, v) -> parameters.put(k, v.toString())); + parameters.putAll(this.additionalParameters); return parameters; } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index 7952db6a54..8f0745d4f2 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -121,7 +121,7 @@ public class OAuth2AuthorizationRequestTests { } @Test - public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() { + public void buildWhenAdditionalParametersEmptyThenDoesNotThrowAnyException() { assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) @@ -129,7 +129,7 @@ public class OAuth2AuthorizationRequestTests { .redirectUri(REDIRECT_URI) .scopes(SCOPES) .state(STATE) - .additionalParameters((Map) null) + .additionalParameters(Map::clear) .build()) .doesNotThrowAnyException(); }