Simplify customizing OAuth2AuthorizationRequest

Fixes gh-7696
This commit is contained in:
Joe Grandja 2019-12-14 20:15:38 -05:00
parent 6123d794e4
commit 23ce717380
7 changed files with 344 additions and 63 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,6 +41,7 @@ import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
/**
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
private final AntPathRequestMatcher authorizationRequestMatcher;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
/**
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
@ -98,6 +100,18 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
return resolve(request, registrationId, redirectUriAction);
}
/**
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
* allowing for further customizations.
*
* @since 5.3
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
*/
public void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
}
private String getAction(HttpServletRequest request, String defaultAction) {
String action = request.getParameter("action");
if (action == null) {
@ -144,16 +158,17 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
OAuth2AuthorizationRequest authorizationRequest = builder
builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.build();
.attributes(attributes);
return authorizationRequest;
this.authorizationRequestCustomizer.accept(builder);
return builder.build();
}
private String resolveRegistrationId(HttpServletRequest request) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,6 +46,7 @@ import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
/**
* The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
/**
* Creates a new instance
* @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
@ -121,6 +124,18 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
}
/**
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
* allowing for further customizations.
*
* @since 5.3
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
*/
public final void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
}
private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
@ -155,13 +170,17 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
return builder
builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.build();
.attributes(attributes);
this.authorizationRequestCustomizer.accept(builder);
return builder.build();
}
/**

View File

@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
@ -70,8 +71,8 @@ public class OAuth2AuthorizationRequestMixinTests {
this.authorizationRequestBuilder
.scopes(null)
.state(null)
.additionalParameters(null)
.attributes(null)
.additionalParameters(Collections.emptyMap())
.attributes(Collections.emptyMap())
.build();
String expectedJson = asJson(authorizationRequest);
String json = this.mapper.writeValueAsString(authorizationRequest);
@ -118,8 +119,8 @@ public class OAuth2AuthorizationRequestMixinTests {
this.authorizationRequestBuilder
.scopes(null)
.state(null)
.additionalParameters(null)
.attributes(null)
.additionalParameters(Collections.emptyMap())
.attributes(Collections.emptyMap())
.build();
String json = asJson(expectedAuthorizationRequest);
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,7 +31,9 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry;
/**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@ -81,6 +83,12 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
String requestUri = "/path";
@ -414,6 +422,76 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
}
// gh-7696
@Test
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id");
}
@Test
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.authorizationRequestUri(uriBuilder -> {
uriBuilder.queryParam("param1", "value1");
return uriBuilder.build();
})
);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"param1=value1");
}
@Test
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
ClientRegistration clientRegistration = this.oidcRegistration;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.parameters(params -> {
params.put("appid", params.get("client_id"));
params.remove("client_id");
})
);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"appid=client-id");
}
private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,6 +37,7 @@ import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.catchThrowableOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@ -59,6 +60,12 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
}
@Test
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenNotMatchThenNull() {
assertThat(resolve("/")).isNull();
@ -139,6 +146,79 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
}
// gh-7696
@Test
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id");
}
@Test
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));
this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.authorizationRequestUri(uriBuilder -> {
uriBuilder.queryParam("param1", "value1");
return uriBuilder.build();
})
);
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&client_id=client-id&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"param1=value1");
}
@Test
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
Mono.just(TestClientRegistrations.clientRegistration()
.scope(OidcScopes.OPENID)
.build()));
this.resolver.setAuthorizationRequestCustomizer(customizer ->
customizer.parameters(params -> {
params.put("appid", params.get("client_id"));
params.remove("client_id");
})
);
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
"response_type=code&" +
"scope=openid&state=.{15,}&" +
"redirect_uri=/login/oauth2/code/registration-id&" +
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
"appid=client-id");
}
private OAuth2AuthorizationRequest resolve(String path) {
ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path));
return this.resolver.resolve(exchange).block();

View File

@ -22,16 +22,21 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.DefaultUriBuilderFactory;
import org.springframework.web.util.UriBuilder;
import org.springframework.web.util.UriUtils;
import java.io.Serializable;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
/**
* A representation of an OAuth 2.0 Authorization Request
@ -108,7 +113,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
/**
* Returns the scope(s).
*
* @return the scope(s)
* @return the scope(s), or an empty {@code Set} if not available
*/
public Set<String> getScopes() {
return this.scopes;
@ -124,31 +129,31 @@ public final class OAuth2AuthorizationRequest implements Serializable {
}
/**
* Returns the additional parameters used in the request.
* Returns the additional parameter(s) used in the request.
*
* @return a {@code Map} of the additional parameters used in the request
* @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if not available
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
/**
* Returns the attributes associated to the request.
* Returns the attribute(s) associated to the request.
*
* @since 5.2
* @return a {@code Map} of the attributes associated to the request
* @return a {@code Map} of the attribute(s), or an empty {@code Map} if not available
*/
public Map<String, Object> getAttributes() {
return this.attributes;
}
/**
* Returns the value of an attribute associated to the request, or {@code null} if not available.
* Returns the value of an attribute associated to the request.
*
* @since 5.2
* @param name the name of the attribute
* @param <T> the type of the attribute
* @return the value of the attribute associated to the request
* @return the value of the attribute associated to the request, or {@code null} if not available
*/
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
@ -219,9 +224,12 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private String redirectUri;
private Set<String> scopes;
private String state;
private Map<String, Object> additionalParameters;
private Consumer<Map<String, Object>> additionalParametersConsumer = params -> {};
private Consumer<Map<String, Object>> parametersConsumer = params -> {};
private Consumer<Map<String, Object>> attributesConsumer = attrs -> {};
private String authorizationRequestUri;
private Map<String, Object> attributes;
private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build();
private final DefaultUriBuilderFactory uriBuilderFactory;
private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
@ -231,6 +239,10 @@ public final class OAuth2AuthorizationRequest implements Serializable {
} else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) {
this.responseType = OAuth2AuthorizationResponseType.TOKEN;
}
this.uriBuilderFactory = new DefaultUriBuilderFactory();
// The supplied authorizationUri may contain encoded parameters
// so disable encoding in UriBuilder and instead apply encoding within this builder
this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
}
/**
@ -274,7 +286,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
*/
public Builder scope(String... scope) {
if (scope != null && scope.length > 0) {
return this.scopes(toLinkedHashSet(scope));
return scopes(new LinkedHashSet<>(Arrays.asList(scope)));
}
return this;
}
@ -302,13 +314,43 @@ public final class OAuth2AuthorizationRequest implements Serializable {
}
/**
* Sets the additional parameters used in the request.
* Sets the additional parameter(s) used in the request.
*
* @param additionalParameters the additional parameters used in the request
* @param additionalParameters the additional parameter(s) used in the request
* @return the {@link Builder}
*/
public Builder additionalParameters(Map<String, Object> additionalParameters) {
this.additionalParameters = additionalParameters;
if (additionalParameters != null) {
return additionalParameters(params -> params.putAll(additionalParameters));
}
return this;
}
/**
* A {@code Consumer} to be provided access to the additional parameter(s)
* allowing the ability to add, replace, or remove.
*
* @since 5.3
* @param additionalParametersConsumer a {@code Consumer} of the additional parameters
*/
public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
if (additionalParametersConsumer != null) {
this.additionalParametersConsumer = additionalParametersConsumer;
}
return this;
}
/**
* A {@code Consumer} to be provided access to all the parameters
* allowing the ability to add, replace, or remove.
*
* @since 5.3
* @param parametersConsumer a {@code Consumer} of all the parameters
*/
public Builder parameters(Consumer<Map<String, Object>> parametersConsumer) {
if (parametersConsumer != null) {
this.parametersConsumer = parametersConsumer;
}
return this;
}
@ -320,7 +362,23 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* @return the {@link Builder}
*/
public Builder attributes(Map<String, Object> attributes) {
this.attributes = attributes;
if (attributes != null) {
return attributes(attrs -> attrs.putAll(attributes));
}
return this;
}
/**
* A {@code Consumer} to be provided access to the attribute(s)
* allowing the ability to add, replace, or remove.
*
* @since 5.3
* @param attributesConsumer a {@code Consumer} of the attribute(s)
*/
public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
if (attributesConsumer != null) {
this.attributesConsumer = attributesConsumer;
}
return this;
}
@ -340,6 +398,20 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this;
}
/**
* A {@code Function} to be provided a {@code UriBuilder} representation
* of the OAuth 2.0 Authorization Request allowing for further customizations.
*
* @since 5.3
* @param authorizationRequestUriFunction a {@code Function} to be provided a {@code UriBuilder} representation of the OAuth 2.0 Authorization Request
*/
public Builder authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
if (authorizationRequestUriFunction != null) {
this.authorizationRequestUriFunction = authorizationRequestUriFunction;
}
return this;
}
/**
* Builds a new {@link OAuth2AuthorizationRequest}.
*
@ -362,53 +434,53 @@ public final class OAuth2AuthorizationRequest implements Serializable {
authorizationRequest.scopes = Collections.unmodifiableSet(
CollectionUtils.isEmpty(this.scopes) ?
Collections.emptySet() : new LinkedHashSet<>(this.scopes));
authorizationRequest.additionalParameters = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.additionalParameters) ?
Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters));
Map<String, Object> additionalParameters = new LinkedHashMap<>();
this.additionalParametersConsumer.accept(additionalParameters);
authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters);
Map<String, Object> attributes = new LinkedHashMap<>();
this.attributesConsumer.accept(attributes);
authorizationRequest.attributes = Collections.unmodifiableMap(attributes);
authorizationRequest.authorizationRequestUri =
StringUtils.hasText(this.authorizationRequestUri) ?
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
authorizationRequest.attributes = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.attributes) ?
Collections.emptyMap() : new LinkedHashMap<>(this.attributes));
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
return authorizationRequest;
}
private String buildAuthorizationRequestUri() {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, encodeQueryParam(this.responseType.getValue()));
parameters.set(OAuth2ParameterNames.CLIENT_ID, encodeQueryParam(this.clientId));
Map<String, Object> parameters = getParameters(); // Not encoded
this.parametersConsumer.accept(parameters);
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
parameters.forEach((k, v) -> queryParams.set(
encodeQueryParam(k), encodeQueryParam(v.toString()))); // Encoded
UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri)
.queryParams(queryParams);
return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
}
private Map<String, Object> getParameters() {
Map<String, Object> parameters = new LinkedHashMap<>();
parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);
if (!CollectionUtils.isEmpty(this.scopes)) {
parameters.set(OAuth2ParameterNames.SCOPE,
encodeQueryParam(StringUtils.collectionToDelimitedString(this.scopes, " ")));
parameters.put(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(this.scopes, " "));
}
if (this.state != null) {
parameters.set(OAuth2ParameterNames.STATE, encodeQueryParam(this.state));
parameters.put(OAuth2ParameterNames.STATE, this.state);
}
if (this.redirectUri != null) {
parameters.set(OAuth2ParameterNames.REDIRECT_URI, encodeQueryParam(this.redirectUri));
parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
}
if (!CollectionUtils.isEmpty(this.additionalParameters)) {
this.additionalParameters.forEach((k, v) ->
parameters.set(encodeQueryParam(k), encodeQueryParam(v.toString())));
}
return UriComponentsBuilder.fromHttpUrl(this.authorizationUri)
.queryParams(parameters)
.build()
.toUriString();
Map<String, Object> additionalParameters = new LinkedHashMap<>();
this.additionalParametersConsumer.accept(additionalParameters);
additionalParameters.forEach((k, v) -> parameters.put(k, v.toString()));
return parameters;
}
// Encode query parameter value according to RFC 3986
private static String encodeQueryParam(String value) {
return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8);
}
private LinkedHashSet<String> toLinkedHashSet(String... scope) {
LinkedHashSet<String> result = new LinkedHashSet<>();
Collections.addAll(result, scope);
return result;
}
}
}

View File

@ -18,13 +18,16 @@ package org.springframework.security.oauth2.core.endpoint;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import java.net.URI;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link OAuth2AuthorizationRequest}.
@ -126,7 +129,7 @@ public class OAuth2AuthorizationRequestTests {
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(null)
.additionalParameters((Map) null)
.build())
.doesNotThrowAnyException();
}
@ -220,6 +223,19 @@ public class OAuth2AuthorizationRequestTests {
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
}
@Test
public void buildWhenAuthorizationRequestUriFunctionSetThenOverridesDefault() {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.authorizationRequestUri(uriBuilder -> URI.create(AUTHORIZATION_URI))
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
}
@Test
public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() {
Map<String, Object> additionalParameters = new HashMap<>();