From 8a791028b1aa6d78a4aab414f5853c95256d0e0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9C=B3=E4=BD=B3?= Date: Wed, 7 Aug 2024 13:35:34 +0800 Subject: [PATCH 1/2] Fix array values of additionalParameters Closes gh-15468 --- .../endpoint/OAuth2AuthorizationRequest.java | 26 +++++++++++++++++-- .../OAuth2AuthorizationRequestTests.java | 23 ++++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java index 9809ea6c1f..c40512a775 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java @@ -23,11 +23,13 @@ import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; - +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; @@ -463,7 +465,13 @@ public final class OAuth2AuthorizationRequest implements Serializable { Map parameters = getParameters(); // Not encoded this.parametersConsumer.accept(parameters); MultiValueMap queryParams = new LinkedMultiValueMap<>(); - parameters.forEach((k, v) -> queryParams.set(encodeQueryParam(k), encodeQueryParam(String.valueOf(v)))); // Encoded + parameters.forEach((key1, value) -> { + String key = encodeQueryParam(key1); + List values = queryValues(value) + .map(o -> encodeQueryParam(String.valueOf(o))) + .toList(); + queryParams.put(key, values); + }); UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri).queryParams(queryParams); return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); } @@ -490,6 +498,20 @@ public final class OAuth2AuthorizationRequest implements Serializable { return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8); } + // Query value as a stream + // If the value is an Iterable or an array it will be converted to a stream + private static Stream queryValues(Object value) { + if (value instanceof Iterable) { + return StreamSupport.stream(((Iterable) value).spliterator(), false); + + } else if (value.getClass().isArray()) { + return Arrays.stream((Object[]) value); + + } else { + return Stream.of(value); + } + } + } } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index 1da76ce63b..f110e678cc 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -16,20 +16,18 @@ package org.springframework.security.oauth2.core.endpoint; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + import java.net.URI; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; - import org.junit.jupiter.api.Test; - import org.springframework.security.oauth2.core.AuthorizationGrantType; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - /** * Tests for {@link OAuth2AuthorizationRequest}. * @@ -364,4 +362,19 @@ public class OAuth2AuthorizationRequestTests { + "item%20amount=19.95%E2%82%AC&%C3%A2ge=4%C2%BD&item%20name=H%C3%85M%C3%96"); } + @Test + public void additionalParametersArrayValueOrIterableEncoded() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put("item", new String[] { "1", "2" }); + additionalParameters.put("item2", Arrays.asList("H" + '\u00c5' + "M" + '\u00d6', "H" + '\u00c5' + "M" + '\u00d6')); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .additionalParameters(additionalParameters) + .build(); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo( + "https://example.com/login/oauth/authorize?" + "response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + + "item=1&item=2&item2=H%C3%85M%C3%96&item2=H%C3%85M%C3%96"); + } + } From 7b7a3044cf6379fc4241d64e4a517b7f6566db34 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:50:43 -0500 Subject: [PATCH 2/2] Polish gh-15533 --- .../endpoint/OAuth2AuthorizationRequest.java | 40 +++++------- .../OAuth2AuthorizationRequestTests.java | 61 ++++++++++++++----- 2 files changed, 63 insertions(+), 38 deletions(-) diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java index c40512a775..20224f6ded 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,13 +23,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; @@ -465,12 +463,20 @@ public final class OAuth2AuthorizationRequest implements Serializable { Map parameters = getParameters(); // Not encoded this.parametersConsumer.accept(parameters); MultiValueMap queryParams = new LinkedMultiValueMap<>(); - parameters.forEach((key1, value) -> { - String key = encodeQueryParam(key1); - List values = queryValues(value) - .map(o -> encodeQueryParam(String.valueOf(o))) - .toList(); - queryParams.put(key, values); + parameters.forEach((k, v) -> { + String key = encodeQueryParam(k); + if (v instanceof Iterable) { + ((Iterable) v).forEach((value) -> queryParams.add(key, encodeQueryParam(String.valueOf(value)))); + } + else if (v != null && v.getClass().isArray()) { + Object[] values = (Object[]) v; + for (Object value : values) { + queryParams.add(key, encodeQueryParam(String.valueOf(value))); + } + } + else { + queryParams.set(key, encodeQueryParam(String.valueOf(v))); + } }); UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri).queryParams(queryParams); return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); @@ -498,20 +504,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8); } - // Query value as a stream - // If the value is an Iterable or an array it will be converted to a stream - private static Stream queryValues(Object value) { - if (value instanceof Iterable) { - return StreamSupport.stream(((Iterable) value).spliterator(), false); - - } else if (value.getClass().isArray()) { - return Arrays.stream((Object[]) value); - - } else { - return Stream.of(value); - } - } - } } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index f110e678cc..46ab89b867 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,18 +16,21 @@ package org.springframework.security.oauth2.core.endpoint; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - import java.net.URI; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.security.oauth2.core.AuthorizationGrantType; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + /** * Tests for {@link OAuth2AuthorizationRequest}. * @@ -363,18 +366,48 @@ public class OAuth2AuthorizationRequestTests { } @Test - public void additionalParametersArrayValueOrIterableEncoded() { - Map additionalParameters = new HashMap<>(); - additionalParameters.put("item", new String[] { "1", "2" }); - additionalParameters.put("item2", Arrays.asList("H" + '\u00c5' + "M" + '\u00d6', "H" + '\u00c5' + "M" + '\u00d6')); + public void buildWhenAdditionalParametersContainsArrayThenProperlyEncoded() { + Map additionalParameters = new LinkedHashMap<>(); + additionalParameters.put("item1", new String[] { "1", "2" }); + additionalParameters.put("item2", "value2"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() - .additionalParameters(additionalParameters) - .build(); + .additionalParameters(additionalParameters) + .build(); assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); - assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo( - "https://example.com/login/oauth/authorize?" + "response_type=code&client_id=client-id&state=state&" - + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" - + "item=1&item=2&item2=H%C3%85M%C3%96&item2=H%C3%85M%C3%96"); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .isEqualTo("https://example.com/login/oauth/authorize?response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + + "item1=1&item1=2&item2=value2"); + } + + @Test + public void buildWhenAdditionalParametersContainsIterableThenProperlyEncoded() { + Map additionalParameters = new LinkedHashMap<>(); + additionalParameters.put("item1", Arrays.asList("1", "2")); + additionalParameters.put("item2", "value2"); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .additionalParameters(additionalParameters) + .build(); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .isEqualTo("https://example.com/login/oauth/authorize?response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + + "item1=1&item1=2&item2=value2"); + } + + @Test + public void buildWhenAdditionalParametersContainsNullThenAuthorizationRequestUriContainsNull() { + Map additionalParameters = new LinkedHashMap<>(); + additionalParameters.put("item1", null); + additionalParameters.put("item2", "value2"); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .additionalParameters(additionalParameters) + .build(); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .isEqualTo("https://example.com/login/oauth/authorize?response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + + "item1=null&item2=value2"); } }