Simplify customizing OAuth2AuthorizationRequest
Fixes gh-7696
This commit is contained in:
parent
6123d794e4
commit
23ce717380
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,6 +41,7 @@ import java.security.NoSuchAlgorithmException;
|
|||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
|
||||
|
@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
|||
private final AntPathRequestMatcher authorizationRequestMatcher;
|
||||
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
|
||||
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
|
||||
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
|
||||
|
||||
/**
|
||||
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
|
||||
|
@ -98,6 +100,18 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
|||
return resolve(request, registrationId, redirectUriAction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
|
||||
* allowing for further customizations.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
|
||||
*/
|
||||
public void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
|
||||
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
|
||||
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
|
||||
}
|
||||
|
||||
private String getAction(HttpServletRequest request, String defaultAction) {
|
||||
String action = request.getParameter("action");
|
||||
if (action == null) {
|
||||
|
@ -144,16 +158,17 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
|||
|
||||
String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = builder
|
||||
builder
|
||||
.clientId(clientRegistration.getClientId())
|
||||
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
|
||||
.redirectUri(redirectUriStr)
|
||||
.scopes(clientRegistration.getScopes())
|
||||
.state(this.stateGenerator.generateKey())
|
||||
.attributes(attributes)
|
||||
.build();
|
||||
.attributes(attributes);
|
||||
|
||||
return authorizationRequest;
|
||||
this.authorizationRequestCustomizer.accept(builder);
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private String resolveRegistrationId(HttpServletRequest request) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,6 +46,7 @@ import java.security.NoSuchAlgorithmException;
|
|||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
|
||||
|
@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
|
|||
|
||||
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
|
||||
|
||||
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
|
||||
|
||||
/**
|
||||
* Creates a new instance
|
||||
* @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
|
||||
|
@ -121,6 +124,18 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
|
|||
.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
|
||||
* allowing for further customizations.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
|
||||
*/
|
||||
public final void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
|
||||
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
|
||||
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
|
||||
}
|
||||
|
||||
private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
|
||||
return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
|
||||
.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
|
||||
|
@ -155,13 +170,17 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
|
|||
"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
|
||||
+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
|
||||
}
|
||||
return builder
|
||||
builder
|
||||
.clientId(clientRegistration.getClientId())
|
||||
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
|
||||
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
|
||||
.redirectUri(redirectUriStr)
|
||||
.scopes(clientRegistration.getScopes())
|
||||
.state(this.stateGenerator.generateKey())
|
||||
.attributes(attributes)
|
||||
.build();
|
||||
.attributes(attributes);
|
||||
|
||||
this.authorizationRequestCustomizer.accept(builder);
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
|
|||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -70,8 +71,8 @@ public class OAuth2AuthorizationRequestMixinTests {
|
|||
this.authorizationRequestBuilder
|
||||
.scopes(null)
|
||||
.state(null)
|
||||
.additionalParameters(null)
|
||||
.attributes(null)
|
||||
.additionalParameters(Collections.emptyMap())
|
||||
.attributes(Collections.emptyMap())
|
||||
.build();
|
||||
String expectedJson = asJson(authorizationRequest);
|
||||
String json = this.mapper.writeValueAsString(authorizationRequest);
|
||||
|
@ -118,8 +119,8 @@ public class OAuth2AuthorizationRequestMixinTests {
|
|||
this.authorizationRequestBuilder
|
||||
.scopes(null)
|
||||
.state(null)
|
||||
.additionalParameters(null)
|
||||
.attributes(null)
|
||||
.additionalParameters(Collections.emptyMap())
|
||||
.attributes(Collections.emptyMap())
|
||||
.build();
|
||||
String json = asJson(expectedAuthorizationRequest);
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,7 +31,9 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
|
|||
import org.springframework.security.oauth2.core.oidc.OidcScopes;
|
||||
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.assertj.core.api.Assertions.entry;
|
||||
|
||||
/**
|
||||
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
|
||||
|
@ -81,6 +83,12 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
|
|||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
|
||||
String requestUri = "/path";
|
||||
|
@ -414,6 +422,76 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
|
|||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
|
||||
}
|
||||
|
||||
// gh-7696
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
|
||||
ClientRegistration clientRegistration = this.oidcRegistration;
|
||||
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
|
||||
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
|
||||
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
|
||||
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
|
||||
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&client_id=client-id&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
|
||||
ClientRegistration clientRegistration = this.oidcRegistration;
|
||||
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer ->
|
||||
customizer.authorizationRequestUri(uriBuilder -> {
|
||||
uriBuilder.queryParam("param1", "value1");
|
||||
return uriBuilder.build();
|
||||
})
|
||||
);
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&client_id=client-id&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
|
||||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
|
||||
"param1=value1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
|
||||
ClientRegistration clientRegistration = this.oidcRegistration;
|
||||
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer ->
|
||||
customizer.parameters(params -> {
|
||||
params.put("appid", params.get("client_id"));
|
||||
params.remove("client_id");
|
||||
})
|
||||
);
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
|
||||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
|
||||
"appid=client-id");
|
||||
}
|
||||
|
||||
private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
|
||||
return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
|
||||
.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -37,6 +37,7 @@ import org.springframework.web.server.ServerWebExchange;
|
|||
import reactor.core.publisher.Mono;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.assertj.core.api.Assertions.catchThrowableOfType;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
@ -59,6 +60,12 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
|
|||
this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenNotMatchThenNull() {
|
||||
assertThat(resolve("/")).isNull();
|
||||
|
@ -139,6 +146,79 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
|
|||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
|
||||
}
|
||||
|
||||
// gh-7696
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
|
||||
Mono.just(TestClientRegistrations.clientRegistration()
|
||||
.scope(OidcScopes.OPENID)
|
||||
.build()));
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
|
||||
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
|
||||
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
|
||||
|
||||
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
|
||||
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&client_id=client-id&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=/login/oauth2/code/registration-id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
|
||||
Mono.just(TestClientRegistrations.clientRegistration()
|
||||
.scope(OidcScopes.OPENID)
|
||||
.build()));
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer ->
|
||||
customizer.authorizationRequestUri(uriBuilder -> {
|
||||
uriBuilder.queryParam("param1", "value1");
|
||||
return uriBuilder.build();
|
||||
})
|
||||
);
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
|
||||
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&client_id=client-id&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=/login/oauth2/code/registration-id&" +
|
||||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
|
||||
"param1=value1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
|
||||
Mono.just(TestClientRegistrations.clientRegistration()
|
||||
.scope(OidcScopes.OPENID)
|
||||
.build()));
|
||||
|
||||
this.resolver.setAuthorizationRequestCustomizer(customizer ->
|
||||
customizer.parameters(params -> {
|
||||
params.put("appid", params.get("client_id"));
|
||||
params.remove("client_id");
|
||||
})
|
||||
);
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
|
||||
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri())
|
||||
.matches("https://example.com/login/oauth/authorize\\?" +
|
||||
"response_type=code&" +
|
||||
"scope=openid&state=.{15,}&" +
|
||||
"redirect_uri=/login/oauth2/code/registration-id&" +
|
||||
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
|
||||
"appid=client-id");
|
||||
}
|
||||
|
||||
private OAuth2AuthorizationRequest resolve(String path) {
|
||||
ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path));
|
||||
return this.resolver.resolve(exchange).block();
|
||||
|
|
|
@ -22,16 +22,21 @@ import org.springframework.util.CollectionUtils;
|
|||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
import org.springframework.web.util.DefaultUriBuilderFactory;
|
||||
import org.springframework.web.util.UriBuilder;
|
||||
import org.springframework.web.util.UriUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* A representation of an OAuth 2.0 Authorization Request
|
||||
|
@ -108,7 +113,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
/**
|
||||
* Returns the scope(s).
|
||||
*
|
||||
* @return the scope(s)
|
||||
* @return the scope(s), or an empty {@code Set} if not available
|
||||
*/
|
||||
public Set<String> getScopes() {
|
||||
return this.scopes;
|
||||
|
@ -124,31 +129,31 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns the additional parameters used in the request.
|
||||
* Returns the additional parameter(s) used in the request.
|
||||
*
|
||||
* @return a {@code Map} of the additional parameters used in the request
|
||||
* @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if not available
|
||||
*/
|
||||
public Map<String, Object> getAdditionalParameters() {
|
||||
return this.additionalParameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the attributes associated to the request.
|
||||
* Returns the attribute(s) associated to the request.
|
||||
*
|
||||
* @since 5.2
|
||||
* @return a {@code Map} of the attributes associated to the request
|
||||
* @return a {@code Map} of the attribute(s), or an empty {@code Map} if not available
|
||||
*/
|
||||
public Map<String, Object> getAttributes() {
|
||||
return this.attributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value of an attribute associated to the request, or {@code null} if not available.
|
||||
* Returns the value of an attribute associated to the request.
|
||||
*
|
||||
* @since 5.2
|
||||
* @param name the name of the attribute
|
||||
* @param <T> the type of the attribute
|
||||
* @return the value of the attribute associated to the request
|
||||
* @return the value of the attribute associated to the request, or {@code null} if not available
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T getAttribute(String name) {
|
||||
|
@ -219,9 +224,12 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
private String redirectUri;
|
||||
private Set<String> scopes;
|
||||
private String state;
|
||||
private Map<String, Object> additionalParameters;
|
||||
private Consumer<Map<String, Object>> additionalParametersConsumer = params -> {};
|
||||
private Consumer<Map<String, Object>> parametersConsumer = params -> {};
|
||||
private Consumer<Map<String, Object>> attributesConsumer = attrs -> {};
|
||||
private String authorizationRequestUri;
|
||||
private Map<String, Object> attributes;
|
||||
private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build();
|
||||
private final DefaultUriBuilderFactory uriBuilderFactory;
|
||||
|
||||
private Builder(AuthorizationGrantType authorizationGrantType) {
|
||||
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
|
||||
|
@ -231,6 +239,10 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
} else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) {
|
||||
this.responseType = OAuth2AuthorizationResponseType.TOKEN;
|
||||
}
|
||||
this.uriBuilderFactory = new DefaultUriBuilderFactory();
|
||||
// The supplied authorizationUri may contain encoded parameters
|
||||
// so disable encoding in UriBuilder and instead apply encoding within this builder
|
||||
this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -274,7 +286,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
*/
|
||||
public Builder scope(String... scope) {
|
||||
if (scope != null && scope.length > 0) {
|
||||
return this.scopes(toLinkedHashSet(scope));
|
||||
return scopes(new LinkedHashSet<>(Arrays.asList(scope)));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
@ -302,13 +314,43 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
}
|
||||
|
||||
/**
|
||||
* Sets the additional parameters used in the request.
|
||||
* Sets the additional parameter(s) used in the request.
|
||||
*
|
||||
* @param additionalParameters the additional parameters used in the request
|
||||
* @param additionalParameters the additional parameter(s) used in the request
|
||||
* @return the {@link Builder}
|
||||
*/
|
||||
public Builder additionalParameters(Map<String, Object> additionalParameters) {
|
||||
this.additionalParameters = additionalParameters;
|
||||
if (additionalParameters != null) {
|
||||
return additionalParameters(params -> params.putAll(additionalParameters));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@code Consumer} to be provided access to the additional parameter(s)
|
||||
* allowing the ability to add, replace, or remove.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param additionalParametersConsumer a {@code Consumer} of the additional parameters
|
||||
*/
|
||||
public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
|
||||
if (additionalParametersConsumer != null) {
|
||||
this.additionalParametersConsumer = additionalParametersConsumer;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@code Consumer} to be provided access to all the parameters
|
||||
* allowing the ability to add, replace, or remove.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param parametersConsumer a {@code Consumer} of all the parameters
|
||||
*/
|
||||
public Builder parameters(Consumer<Map<String, Object>> parametersConsumer) {
|
||||
if (parametersConsumer != null) {
|
||||
this.parametersConsumer = parametersConsumer;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -320,7 +362,23 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
* @return the {@link Builder}
|
||||
*/
|
||||
public Builder attributes(Map<String, Object> attributes) {
|
||||
this.attributes = attributes;
|
||||
if (attributes != null) {
|
||||
return attributes(attrs -> attrs.putAll(attributes));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@code Consumer} to be provided access to the attribute(s)
|
||||
* allowing the ability to add, replace, or remove.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param attributesConsumer a {@code Consumer} of the attribute(s)
|
||||
*/
|
||||
public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
|
||||
if (attributesConsumer != null) {
|
||||
this.attributesConsumer = attributesConsumer;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -340,6 +398,20 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@code Function} to be provided a {@code UriBuilder} representation
|
||||
* of the OAuth 2.0 Authorization Request allowing for further customizations.
|
||||
*
|
||||
* @since 5.3
|
||||
* @param authorizationRequestUriFunction a {@code Function} to be provided a {@code UriBuilder} representation of the OAuth 2.0 Authorization Request
|
||||
*/
|
||||
public Builder authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
|
||||
if (authorizationRequestUriFunction != null) {
|
||||
this.authorizationRequestUriFunction = authorizationRequestUriFunction;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a new {@link OAuth2AuthorizationRequest}.
|
||||
*
|
||||
|
@ -362,53 +434,53 @@ public final class OAuth2AuthorizationRequest implements Serializable {
|
|||
authorizationRequest.scopes = Collections.unmodifiableSet(
|
||||
CollectionUtils.isEmpty(this.scopes) ?
|
||||
Collections.emptySet() : new LinkedHashSet<>(this.scopes));
|
||||
authorizationRequest.additionalParameters = Collections.unmodifiableMap(
|
||||
CollectionUtils.isEmpty(this.additionalParameters) ?
|
||||
Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters));
|
||||
Map<String, Object> additionalParameters = new LinkedHashMap<>();
|
||||
this.additionalParametersConsumer.accept(additionalParameters);
|
||||
authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters);
|
||||
Map<String, Object> attributes = new LinkedHashMap<>();
|
||||
this.attributesConsumer.accept(attributes);
|
||||
authorizationRequest.attributes = Collections.unmodifiableMap(attributes);
|
||||
authorizationRequest.authorizationRequestUri =
|
||||
StringUtils.hasText(this.authorizationRequestUri) ?
|
||||
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
|
||||
authorizationRequest.attributes = Collections.unmodifiableMap(
|
||||
CollectionUtils.isEmpty(this.attributes) ?
|
||||
Collections.emptyMap() : new LinkedHashMap<>(this.attributes));
|
||||
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
|
||||
|
||||
return authorizationRequest;
|
||||
}
|
||||
|
||||
private String buildAuthorizationRequestUri() {
|
||||
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
|
||||
parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, encodeQueryParam(this.responseType.getValue()));
|
||||
parameters.set(OAuth2ParameterNames.CLIENT_ID, encodeQueryParam(this.clientId));
|
||||
Map<String, Object> parameters = getParameters(); // Not encoded
|
||||
this.parametersConsumer.accept(parameters);
|
||||
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
|
||||
parameters.forEach((k, v) -> queryParams.set(
|
||||
encodeQueryParam(k), encodeQueryParam(v.toString()))); // Encoded
|
||||
UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri)
|
||||
.queryParams(queryParams);
|
||||
return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
|
||||
}
|
||||
|
||||
private Map<String, Object> getParameters() {
|
||||
Map<String, Object> parameters = new LinkedHashMap<>();
|
||||
parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
|
||||
parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);
|
||||
if (!CollectionUtils.isEmpty(this.scopes)) {
|
||||
parameters.set(OAuth2ParameterNames.SCOPE,
|
||||
encodeQueryParam(StringUtils.collectionToDelimitedString(this.scopes, " ")));
|
||||
parameters.put(OAuth2ParameterNames.SCOPE,
|
||||
StringUtils.collectionToDelimitedString(this.scopes, " "));
|
||||
}
|
||||
if (this.state != null) {
|
||||
parameters.set(OAuth2ParameterNames.STATE, encodeQueryParam(this.state));
|
||||
parameters.put(OAuth2ParameterNames.STATE, this.state);
|
||||
}
|
||||
if (this.redirectUri != null) {
|
||||
parameters.set(OAuth2ParameterNames.REDIRECT_URI, encodeQueryParam(this.redirectUri));
|
||||
parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(this.additionalParameters)) {
|
||||
this.additionalParameters.forEach((k, v) ->
|
||||
parameters.set(encodeQueryParam(k), encodeQueryParam(v.toString())));
|
||||
}
|
||||
|
||||
return UriComponentsBuilder.fromHttpUrl(this.authorizationUri)
|
||||
.queryParams(parameters)
|
||||
.build()
|
||||
.toUriString();
|
||||
Map<String, Object> additionalParameters = new LinkedHashMap<>();
|
||||
this.additionalParametersConsumer.accept(additionalParameters);
|
||||
additionalParameters.forEach((k, v) -> parameters.put(k, v.toString()));
|
||||
return parameters;
|
||||
}
|
||||
|
||||
// Encode query parameter value according to RFC 3986
|
||||
private static String encodeQueryParam(String value) {
|
||||
return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
private LinkedHashSet<String> toLinkedHashSet(String... scope) {
|
||||
LinkedHashSet<String> result = new LinkedHashSet<>();
|
||||
Collections.addAll(result, scope);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,13 +18,16 @@ package org.springframework.security.oauth2.core.endpoint;
|
|||
import org.junit.Test;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2AuthorizationRequest}.
|
||||
|
@ -126,7 +129,7 @@ public class OAuth2AuthorizationRequestTests {
|
|||
.redirectUri(REDIRECT_URI)
|
||||
.scopes(SCOPES)
|
||||
.state(STATE)
|
||||
.additionalParameters(null)
|
||||
.additionalParameters((Map) null)
|
||||
.build())
|
||||
.doesNotThrowAnyException();
|
||||
}
|
||||
|
@ -220,6 +223,19 @@ public class OAuth2AuthorizationRequestTests {
|
|||
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void buildWhenAuthorizationRequestUriFunctionSetThenOverridesDefault() {
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
||||
.authorizationUri(AUTHORIZATION_URI)
|
||||
.clientId(CLIENT_ID)
|
||||
.redirectUri(REDIRECT_URI)
|
||||
.scopes(SCOPES)
|
||||
.state(STATE)
|
||||
.authorizationRequestUri(uriBuilder -> URI.create(AUTHORIZATION_URI))
|
||||
.build();
|
||||
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() {
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
|
|
Loading…
Reference in New Issue