diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java index b104796687..a97db4a26b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ import java.security.NoSuchAlgorithmException; import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.function.Consumer; /** * An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to @@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au private final AntPathRequestMatcher authorizationRequestMatcher; private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + private Consumer authorizationRequestCustomizer = customizer -> {}; /** * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters. @@ -98,6 +100,18 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au return resolve(request, registrationId, redirectUriAction); } + /** + * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} + * allowing for further customizations. + * + * @since 5.3 + * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} + */ + public void setAuthorizationRequestCustomizer(Consumer authorizationRequestCustomizer) { + Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null"); + this.authorizationRequestCustomizer = authorizationRequestCustomizer; + } + private String getAction(HttpServletRequest request, String defaultAction) { String action = request.getParameter("action"); if (action == null) { @@ -144,16 +158,17 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction); - OAuth2AuthorizationRequest authorizationRequest = builder + builder .clientId(clientRegistration.getClientId()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .redirectUri(redirectUriStr) .scopes(clientRegistration.getScopes()) .state(this.stateGenerator.generateKey()) - .attributes(attributes) - .build(); + .attributes(attributes); - return authorizationRequest; + this.authorizationRequestCustomizer.accept(builder); + + return builder.build(); } private String resolveRegistrationId(HttpServletRequest request) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java index 1410099d93..770e4f9982 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ import java.security.NoSuchAlgorithmException; import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.function.Consumer; /** * The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}. @@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + private Consumer authorizationRequestCustomizer = customizer -> {}; + /** * Creates a new instance * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration} @@ -121,6 +124,18 @@ public class DefaultServerOAuth2AuthorizationRequestResolver .map(clientRegistration -> authorizationRequest(exchange, clientRegistration)); } + /** + * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} + * allowing for further customizations. + * + * @since 5.3 + * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} + */ + public final void setAuthorizationRequestCustomizer(Consumer authorizationRequestCustomizer) { + Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null"); + this.authorizationRequestCustomizer = authorizationRequestCustomizer; + } + private Mono findByRegistrationId(ServerWebExchange exchange, String clientRegistration) { return this.clientRegistrationRepository.findByRegistrationId(clientRegistration) .switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id"))); @@ -155,13 +170,17 @@ public class DefaultServerOAuth2AuthorizationRequestResolver "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); } - return builder + builder .clientId(clientRegistration.getClientId()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) - .redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) + .redirectUri(redirectUriStr) + .scopes(clientRegistration.getScopes()) .state(this.stateGenerator.generateKey()) - .attributes(attributes) - .build(); + .attributes(attributes); + + this.authorizationRequestCustomizer.accept(builder); + + return builder.build(); } /** diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java index f7bec9f1bd..21ca65fb87 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java @@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.stream.Collectors; @@ -70,8 +71,8 @@ public class OAuth2AuthorizationRequestMixinTests { this.authorizationRequestBuilder .scopes(null) .state(null) - .additionalParameters(null) - .attributes(null) + .additionalParameters(Collections.emptyMap()) + .attributes(Collections.emptyMap()) .build(); String expectedJson = asJson(authorizationRequest); String json = this.mapper.writeValueAsString(authorizationRequest); @@ -118,8 +119,8 @@ public class OAuth2AuthorizationRequestMixinTests { this.authorizationRequestBuilder .scopes(null) .state(null) - .additionalParameters(null) - .attributes(null) + .additionalParameters(Collections.emptyMap()) + .attributes(Collections.emptyMap()) .build(); String json = asJson(expectedAuthorizationRequest); OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 1fd1ccb34d..45cf3897ef 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,9 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}. @@ -81,6 +83,12 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() { String requestUri = "/path"; @@ -414,6 +422,76 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } + // gh-7696 + @Test + public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { + ClientRegistration clientRegistration = this.oidcRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer + .additionalParameters(params -> params.remove(OidcParameterNames.NONCE)) + .attributes(attrs -> attrs.remove(OidcParameterNames.NONCE))); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id"); + } + + @Test + public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { + ClientRegistration clientRegistration = this.oidcRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> + customizer.authorizationRequestUri(uriBuilder -> { + uriBuilder.queryParam("param1", "value1"); + return uriBuilder.build(); + }) + ); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + + "param1=value1"); + } + + @Test + public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { + ClientRegistration clientRegistration = this.oidcRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> + customizer.parameters(params -> { + params.put("appid", params.get("client_id")); + params.remove("client_id"); + }) + ); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + + "appid=client-id"); + } + private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() { return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration") .redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java index 0febd5025d..77479bcf8e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.catchThrowableOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -59,6 +60,12 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository); } + @Test + public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void resolveWhenNotMatchThenNull() { assertThat(resolve("/")).isNull(); @@ -139,6 +146,79 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } + // gh-7696 + @Test + public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( + Mono.just(TestClientRegistrations.clientRegistration() + .scope(OidcScopes.OPENID) + .build())); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer + .additionalParameters(params -> params.remove(OidcParameterNames.NONCE)) + .attributes(attrs -> attrs.remove(OidcParameterNames.NONCE))); + + OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); + + assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=/login/oauth2/code/registration-id"); + } + + @Test + public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( + Mono.just(TestClientRegistrations.clientRegistration() + .scope(OidcScopes.OPENID) + .build())); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> + customizer.authorizationRequestUri(uriBuilder -> { + uriBuilder.queryParam("param1", "value1"); + return uriBuilder.build(); + }) + ); + + OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); + + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=/login/oauth2/code/registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + + "param1=value1"); + } + + @Test + public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( + Mono.just(TestClientRegistrations.clientRegistration() + .scope(OidcScopes.OPENID) + .build())); + + this.resolver.setAuthorizationRequestCustomizer(customizer -> + customizer.parameters(params -> { + params.put("appid", params.get("client_id")); + params.remove("client_id"); + }) + ); + + OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); + + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=/login/oauth2/code/registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + + "appid=client-id"); + } + private OAuth2AuthorizationRequest resolve(String path) { ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path)); return this.resolver.resolve(exchange).block(); diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java index f6490355ea..85c7a89814 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java @@ -22,16 +22,21 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriUtils; import java.io.Serializable; +import java.net.URI; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; /** * A representation of an OAuth 2.0 Authorization Request @@ -108,7 +113,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the scope(s). * - * @return the scope(s) + * @return the scope(s), or an empty {@code Set} if not available */ public Set getScopes() { return this.scopes; @@ -124,31 +129,31 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * Returns the additional parameters used in the request. + * Returns the additional parameter(s) used in the request. * - * @return a {@code Map} of the additional parameters used in the request + * @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if not available */ public Map getAdditionalParameters() { return this.additionalParameters; } /** - * Returns the attributes associated to the request. + * Returns the attribute(s) associated to the request. * * @since 5.2 - * @return a {@code Map} of the attributes associated to the request + * @return a {@code Map} of the attribute(s), or an empty {@code Map} if not available */ public Map getAttributes() { return this.attributes; } /** - * Returns the value of an attribute associated to the request, or {@code null} if not available. + * Returns the value of an attribute associated to the request. * * @since 5.2 * @param name the name of the attribute * @param the type of the attribute - * @return the value of the attribute associated to the request + * @return the value of the attribute associated to the request, or {@code null} if not available */ @SuppressWarnings("unchecked") public T getAttribute(String name) { @@ -219,9 +224,12 @@ public final class OAuth2AuthorizationRequest implements Serializable { private String redirectUri; private Set scopes; private String state; - private Map additionalParameters; + private Consumer> additionalParametersConsumer = params -> {}; + private Consumer> parametersConsumer = params -> {}; + private Consumer> attributesConsumer = attrs -> {}; private String authorizationRequestUri; - private Map attributes; + private Function authorizationRequestUriFunction = builder -> builder.build(); + private final DefaultUriBuilderFactory uriBuilderFactory; private Builder(AuthorizationGrantType authorizationGrantType) { Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null"); @@ -231,6 +239,10 @@ public final class OAuth2AuthorizationRequest implements Serializable { } else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) { this.responseType = OAuth2AuthorizationResponseType.TOKEN; } + this.uriBuilderFactory = new DefaultUriBuilderFactory(); + // The supplied authorizationUri may contain encoded parameters + // so disable encoding in UriBuilder and instead apply encoding within this builder + this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE); } /** @@ -274,7 +286,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { */ public Builder scope(String... scope) { if (scope != null && scope.length > 0) { - return this.scopes(toLinkedHashSet(scope)); + return scopes(new LinkedHashSet<>(Arrays.asList(scope))); } return this; } @@ -302,13 +314,43 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * Sets the additional parameters used in the request. + * Sets the additional parameter(s) used in the request. * - * @param additionalParameters the additional parameters used in the request + * @param additionalParameters the additional parameter(s) used in the request * @return the {@link Builder} */ public Builder additionalParameters(Map additionalParameters) { - this.additionalParameters = additionalParameters; + if (additionalParameters != null) { + return additionalParameters(params -> params.putAll(additionalParameters)); + } + return this; + } + + /** + * A {@code Consumer} to be provided access to the additional parameter(s) + * allowing the ability to add, replace, or remove. + * + * @since 5.3 + * @param additionalParametersConsumer a {@code Consumer} of the additional parameters + */ + public Builder additionalParameters(Consumer> additionalParametersConsumer) { + if (additionalParametersConsumer != null) { + this.additionalParametersConsumer = additionalParametersConsumer; + } + return this; + } + + /** + * A {@code Consumer} to be provided access to all the parameters + * allowing the ability to add, replace, or remove. + * + * @since 5.3 + * @param parametersConsumer a {@code Consumer} of all the parameters + */ + public Builder parameters(Consumer> parametersConsumer) { + if (parametersConsumer != null) { + this.parametersConsumer = parametersConsumer; + } return this; } @@ -320,7 +362,23 @@ public final class OAuth2AuthorizationRequest implements Serializable { * @return the {@link Builder} */ public Builder attributes(Map attributes) { - this.attributes = attributes; + if (attributes != null) { + return attributes(attrs -> attrs.putAll(attributes)); + } + return this; + } + + /** + * A {@code Consumer} to be provided access to the attribute(s) + * allowing the ability to add, replace, or remove. + * + * @since 5.3 + * @param attributesConsumer a {@code Consumer} of the attribute(s) + */ + public Builder attributes(Consumer> attributesConsumer) { + if (attributesConsumer != null) { + this.attributesConsumer = attributesConsumer; + } return this; } @@ -340,6 +398,20 @@ public final class OAuth2AuthorizationRequest implements Serializable { return this; } + /** + * A {@code Function} to be provided a {@code UriBuilder} representation + * of the OAuth 2.0 Authorization Request allowing for further customizations. + * + * @since 5.3 + * @param authorizationRequestUriFunction a {@code Function} to be provided a {@code UriBuilder} representation of the OAuth 2.0 Authorization Request + */ + public Builder authorizationRequestUri(Function authorizationRequestUriFunction) { + if (authorizationRequestUriFunction != null) { + this.authorizationRequestUriFunction = authorizationRequestUriFunction; + } + return this; + } + /** * Builds a new {@link OAuth2AuthorizationRequest}. * @@ -362,53 +434,53 @@ public final class OAuth2AuthorizationRequest implements Serializable { authorizationRequest.scopes = Collections.unmodifiableSet( CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes)); - authorizationRequest.additionalParameters = Collections.unmodifiableMap( - CollectionUtils.isEmpty(this.additionalParameters) ? - Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters)); + Map additionalParameters = new LinkedHashMap<>(); + this.additionalParametersConsumer.accept(additionalParameters); + authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters); + Map attributes = new LinkedHashMap<>(); + this.attributesConsumer.accept(attributes); + authorizationRequest.attributes = Collections.unmodifiableMap(attributes); authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri) ? - this.authorizationRequestUri : this.buildAuthorizationRequestUri(); - authorizationRequest.attributes = Collections.unmodifiableMap( - CollectionUtils.isEmpty(this.attributes) ? - Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + this.authorizationRequestUri : this.buildAuthorizationRequestUri(); return authorizationRequest; } private String buildAuthorizationRequestUri() { - MultiValueMap parameters = new LinkedMultiValueMap<>(); - parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, encodeQueryParam(this.responseType.getValue())); - parameters.set(OAuth2ParameterNames.CLIENT_ID, encodeQueryParam(this.clientId)); + Map parameters = getParameters(); // Not encoded + this.parametersConsumer.accept(parameters); + MultiValueMap queryParams = new LinkedMultiValueMap<>(); + parameters.forEach((k, v) -> queryParams.set( + encodeQueryParam(k), encodeQueryParam(v.toString()))); // Encoded + UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri) + .queryParams(queryParams); + return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); + } + + private Map getParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue()); + parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId); if (!CollectionUtils.isEmpty(this.scopes)) { - parameters.set(OAuth2ParameterNames.SCOPE, - encodeQueryParam(StringUtils.collectionToDelimitedString(this.scopes, " "))); + parameters.put(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(this.scopes, " ")); } if (this.state != null) { - parameters.set(OAuth2ParameterNames.STATE, encodeQueryParam(this.state)); + parameters.put(OAuth2ParameterNames.STATE, this.state); } if (this.redirectUri != null) { - parameters.set(OAuth2ParameterNames.REDIRECT_URI, encodeQueryParam(this.redirectUri)); + parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri); } - if (!CollectionUtils.isEmpty(this.additionalParameters)) { - this.additionalParameters.forEach((k, v) -> - parameters.set(encodeQueryParam(k), encodeQueryParam(v.toString()))); - } - - return UriComponentsBuilder.fromHttpUrl(this.authorizationUri) - .queryParams(parameters) - .build() - .toUriString(); + Map additionalParameters = new LinkedHashMap<>(); + this.additionalParametersConsumer.accept(additionalParameters); + additionalParameters.forEach((k, v) -> parameters.put(k, v.toString())); + return parameters; } // Encode query parameter value according to RFC 3986 private static String encodeQueryParam(String value) { return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8); } - - private LinkedHashSet toLinkedHashSet(String... scope) { - LinkedHashSet result = new LinkedHashSet<>(); - Collections.addAll(result, scope); - return result; - } } } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index 6480442f53..7952db6a54 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -18,13 +18,16 @@ package org.springframework.security.oauth2.core.endpoint; import org.junit.Test; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import java.net.URI; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link OAuth2AuthorizationRequest}. @@ -126,7 +129,7 @@ public class OAuth2AuthorizationRequestTests { .redirectUri(REDIRECT_URI) .scopes(SCOPES) .state(STATE) - .additionalParameters(null) + .additionalParameters((Map) null) .build()) .doesNotThrowAnyException(); } @@ -220,6 +223,19 @@ public class OAuth2AuthorizationRequestTests { assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI); } + @Test + public void buildWhenAuthorizationRequestUriFunctionSetThenOverridesDefault() { + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(STATE) + .authorizationRequestUri(uriBuilder -> URI.create(AUTHORIZATION_URI)) + .build(); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI); + } + @Test public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() { Map additionalParameters = new HashMap<>();